# Simple NNs with custom embeddings

## Importing data

In [7]:
import copy
from itertools import product
from numpy import isnan
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.data as data
import sys
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../data_pipeline/')
import preprocessing as pre
from training import TrainingModule

SEED = 1312
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [8]:
train_data, test_data, val_data, TEXT, LABEL = pre.get_data(
    'train_small.csv', 'val_small.csv', 'test_small.csv')

Connected!


In [3]:
USE_CUDA = torch.cuda.is_available()

TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)

BATCH_SIZE = 5

train_it, test_it, val_it = data.BucketIterator.splits(
    (train_data, test_data, val_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.alj_text),
    sort_within_batch=True,
    device = torch.device('cuda' if USE_CUDA else 'cpu'))

## Simple NN Model

In [4]:
class WordEmbAvg(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, pad_idx, two_layers=True, dropout_p=0.0):
        
        super().__init__()
        
        # Define an embedding layer, a couple of linear layers, and 
        # the ReLU non-linearity.

        ##YOUR CODE HERE##
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        if two_layers == True:
            self.linear1 = nn.Linear(embedding_dim, hidden_dim)
            self.linear2 = nn.Linear(hidden_dim, output_dim) 
        else:
            self.linear1 = nn.Linear(embedding_dim, output_dim)
            self.linear2 = None
        self.relu = nn.ReLU()
        self.drop_layer = nn.Dropout(p=dropout_p)

        
        
    def forward(self, text):

        ##YOUR CODE HERE##
        embedded = self.embedding(text)
        embedded = embedded.mean(0)
        if not self.linear2:
            linear1_output = self.linear1(embedded)
            output = self.relu(linear1_output)
            output = self.drop_layer(output)
            return output
        else:
            linear1_output = self.linear1(embedded)
            linear2_input = self.relu(linear1_output)
            output = self.linear2(linear2_input)
            output = self.drop_layer(output)
            return output

## Training models

In [5]:
# Store training results

df = pd.DataFrame(columns=['architecture', 'embeddings',
                           'hidden', 'dropouts',
                           'learning_rate', 'epochs',
                           'dev_acc', 'dev_prec', 'dev_recall',
                           'metric'])

# Model architecture parameters
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_SIZES = [32, 64, 128, 256]
HIDDEN_SIZES = [10, 25, 40, 50]
OUTPUT_SIZE = 1
DROPOUTS = [0, 0.1, 0.25, 0.5, 0.75]
PADDING_IDX = TEXT.vocab.stoi[TEXT.pad_token]

# Model training hyperparameters
LEARNING_RATE = [0.01, 0.001, 0.0001]
train_len = 0
train_pos = 0
for batch in train_it:
    train_len += len(batch.decision_binary)
    train_pos += batch.decision_binary.sum().item()
POS_WEIGHT = torch.tensor([(train_len - train_pos) / train_pos])
if USE_CUDA:
    POS_WEIGHT = POS_WEIGHT.cuda()
EPOCHS = 10

# Iterator over various model parameters
param_iter = product (EMBEDDING_SIZES, HIDDEN_SIZES,
                      DROPOUTS, LEARNING_RATE)

# Magic loop
best_acc = (None, None)
best_rec = (None, None)
best_prec = (None, None)
for i, (embed_size, hidden_size, dropout, lr) in enumerate(param_iter):
    print(f'Architecture #{i}\n' + '-' * 20)
    model = WordEmbAvg(INPUT_DIM, embed_size, hidden_size,
                OUTPUT_SIZE, PADDING_IDX, dropout_p=dropout)
    
    tm = TrainingModule(model, lr, POS_WEIGHT, USE_CUDA, EPOCHS)
    
    best_models = tm.train_model(train_it, val_it)
    
    for metric, best_model in best_models.items():
        row = [i, embed_size, hidden_size, dropout,
               lr, EPOCHS, best_model.accuracy,
               best_model.precision, best_model.recall, metric]
        df.loc[len(df)] = row
        if best_acc[0] is None or best_model.accuracy > best_acc[1]:
            best_acc = (copy.deepcopy(best_model.model), best_model.accuracy)
        if best_rec[0] is None or best_model.recall > best_rec[1]:
            best_rec = (copy.deepcopy(best_model.model), best_model.recall)
        if best_prec[0] is None or isnan(best_prec[1]) or\
          best_model.precision > best_prec[1]:
            best_prec = (copy.deepcopy(best_model.model), best_model.precision)
    
    print('-' * 20 + '\n')


Architecture #0
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8839
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8762
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8557
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8205
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7942
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7450
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.6790
Epoch 7: Dev Accuracy: 1.0000; Dev Precision: 1.0000; Dev Recall: 1.0000; Dev Loss:0.6361
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.6018
Epoch 9: Dev Accuracy: 1.0000; Dev Precision: 1.0000; Dev Recall: 1.0000; Dev Loss:0.5483
--------------------

Architecture #1
--------------------
Epoch 0: Dev Accuracy: 0.750

Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8820
--------------------

Architecture #9
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8955
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8817
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8733
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8753
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8611
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8530
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8349
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8281
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8288
Epoch 9: Dev Accuracy: 0.5000; Dev Precision: 0.00

Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8927
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8925
Epoch 8: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8924
Epoch 9: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8923
--------------------

Architecture #18
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9093
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8750
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8668
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8473
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8087
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.7449
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000

Epoch 5: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9002
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8997
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8995
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8995
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8994
--------------------

Architecture #27
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.9021
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9305
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9156
Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8872
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8860
Epoch 5: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Rec

Epoch 3: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9082
Epoch 4: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9081
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9080
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9080
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9081
Epoch 8: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9081
Epoch 9: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9080
--------------------

Architecture #36
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8894
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8820
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8787
Epoch 3: Dev Accuracy: 0.7500; Dev Precisio

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8951
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8951
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8945
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8940
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8934
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8927
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8923
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8918
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8912
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8903
--------------------

Architecture #45
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000;

Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8812
--------------------

Architecture #53
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8791
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8791
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8788
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8783
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8782
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8779
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8777
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8775
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8772
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000;

Epoch 7: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8613
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8591
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8560
--------------------

Architecture #62
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9377
Epoch 1: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9373
Epoch 2: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9367
Epoch 3: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9359
Epoch 4: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9352
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9346
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9342
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.2

Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8657
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8638
Epoch 5: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8627
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8624
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8624
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8633
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8639
--------------------

Architecture #71
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8809
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8809
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8801
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; D

Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8915
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8902
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8877
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8841
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8800
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8756
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8734
--------------------

Architecture #80
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9029
Epoch 1: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9025
Epoch 2: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9023
Epoch 3: Dev Accuracy: 0.2500; Dev Preci

Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8638
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8626
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8631
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8640
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8637
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8640
--------------------

Architecture #89
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8936
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8932
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8928
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8921
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev 

Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8657
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8614
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8558
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8516
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8497
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8451
--------------------

Architecture #98
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8885
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8884
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8884
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8880
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev 

Epoch 4: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8968
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8948
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8922
Epoch 7: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8892
Epoch 8: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8871
Epoch 9: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8843
--------------------

Architecture #107
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9020
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9013
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9009
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9006
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; 

Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8833
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8816
Epoch 6: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8810
Epoch 7: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8805
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8772
Epoch 9: Dev Accuracy: 1.0000; Dev Precision: 1.0000; Dev Recall: 1.0000; Dev Loss:0.8747
--------------------

Architecture #116
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8864
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8859
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8858
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8855
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; 

Epoch 3: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9223
Epoch 4: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9214
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9205
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9201
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9196
Epoch 8: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.9189
Epoch 9: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.9192
--------------------

Architecture #125
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8928
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8925
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8923
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: na

Epoch 1: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9155
Epoch 2: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9117
Epoch 3: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9131
Epoch 4: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9117
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9089
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9070
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9058
Epoch 8: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9061
Epoch 9: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9072
--------------------

Architecture #134
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9064
Epoch 1: Dev Accuracy: 0.2500; Dev Prec

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8922
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8780
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8735
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8703
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8720
Epoch 5: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8698
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8693
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8689
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8683
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8689
--------------------

Architecture #143
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev

Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8938
Epoch 1: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.8904
Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8890
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8865
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8842
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8831
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8795
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8768
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8711
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8652
--------------------

Architecture #152
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Prec

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8728
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8733
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8853
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8914
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8785
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8851
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8814
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8750
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8688
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8677
--------------------

Architecture #161
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precisi

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8828
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8773
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8754
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8763
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8769
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8773
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8757
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8750
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8746
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8726
--------------------

Architecture #170
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision:

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9273
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.9162
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8964
Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8949
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8949
Epoch 5: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8929
Epoch 6: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8914
Epoch 7: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8891
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8875
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8866
--------------------

Architecture #179
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8814
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8791
Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8782
Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8767
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8733
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8721
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8684
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8645
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8611
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8564
--------------------

Architecture #188
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precisi

Epoch 7: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7280
Epoch 8: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.7352
Epoch 9: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:1.2348
--------------------

Architecture #196
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8814
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8718
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8698
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8703
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8712
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8658
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8646
Epoch 7: Dev Accuracy: 0.7500; Dev Precision:

Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.7723
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8040
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.7374
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.6707
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.6460
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.5997
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.6460
--------------------

Architecture #205
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.2500; Dev Recall: 1.0000; Dev Loss:0.9020
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.9008
Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.9118
Epoch 3: Dev Accuracy: 0.5000; Dev Precision:

Epoch 0: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8540
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:1.1552
Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.6981
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.6315
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.6641
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.5011
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.6110
Epoch 7: Dev Accuracy: 1.0000; Dev Precision: 1.0000; Dev Recall: 1.0000; Dev Loss:0.3934
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7203
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7404
--------------------

Architecture #214
--------------------
Epoch 0: Dev Accuracy: 0.2500; Dev Precision: 0.

Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8745
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8746
--------------------

Architecture #222
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:1.1760
Epoch 1: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8796
Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.9284
Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8731
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8946
Epoch 5: Dev Accuracy: 0.5000; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8879
Epoch 6: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8266
Epoch 7: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.7971
Epoch 8: Dev Accuracy: 0.7500; Dev Precisi

Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8854
Epoch 6: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8851
Epoch 7: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8846
Epoch 8: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8846
Epoch 9: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8840
--------------------

Architecture #231
--------------------
Epoch 0: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.8794
Epoch 1: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8522
Epoch 2: Dev Accuracy: 0.7500; Dev Precision: 0.5000; Dev Recall: 1.0000; Dev Loss:0.8712
Epoch 3: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.7120
Epoch 4: Dev Accuracy: 0.7500; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.6876
Epoch 5: Dev Accuracy: 0.7500; Dev Precision: 0.

Epoch 2: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8955
Epoch 3: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8954
Epoch 4: Dev Accuracy: 0.5000; Dev Precision: 0.3333; Dev Recall: 1.0000; Dev Loss:0.8951
Epoch 5: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8960
Epoch 6: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8965
Epoch 7: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8961
Epoch 8: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8962
Epoch 9: Dev Accuracy: 0.2500; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.8968
--------------------



## Save results of model training

In [6]:
SAVE_PREFIX = '../results/SimpleNNCustom_'
df.to_csv(f'{SAVE_PREFIX}models.csv')
torch.save(best_acc[0], f'{SAVE_PREFIX}best_acc.pt')
torch.save(best_rec[0], f'{SAVE_PREFIX}best_rec.pt')
torch.save(best_prec[0], f'{SAVE_PREFIX}best_prec.pt')

  "type " + obj.__name__ + ". It won't be checked "
