# Simple NNs with custom embeddings

## Importing data

In [1]:
import copy
from itertools import product
from numpy import isnan
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.data as data
import sys
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../data_pipeline/')
import preprocessing as pre
from training import TrainingModule

SEED = 1312
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [2]:
train_data, test_data, val_data, TEXT, LABEL = pre.get_data(
    'train_small.csv', 'val_small.csv', 'test_small.csv')

Connected!


In [3]:
USE_CUDA = torch.cuda.is_available()

TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)

BATCH_SIZE = 5

train_it, test_it, val_it = data.BucketIterator.splits(
    (train_data, test_data, val_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.alj_text),
    sort_within_batch=True,
    device = torch.device('cuda' if USE_CUDA else 'cpu'))

## Simple NN Model

In [4]:
class WordEmbAvg(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, pad_idx, two_layers=True, dropout_p=0.0):
        
        super().__init__()
        
        # Define an embedding layer, a couple of linear layers, and 
        # the ReLU non-linearity.

        ##YOUR CODE HERE##
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        if two_layers == True:
            self.linear1 = nn.Linear(embedding_dim, hidden_dim)
            self.linear2 = nn.Linear(hidden_dim, output_dim) 
        else:
            self.linear1 = nn.Linear(embedding_dim, output_dim)
            self.linear2 = None
        self.relu = nn.ReLU()
        self.drop_layer = nn.Dropout(p=dropout_p)

        
        
    def forward(self, text):

        ##YOUR CODE HERE##
        embedded = self.embedding(text)
        embedded = embedded.mean(0)
        if not self.linear2:
            linear1_output = self.linear1(embedded)
            output = self.relu(linear1_output)
            output = self.drop_layer(output)
            return output
        else:
            linear1_output = self.linear1(embedded)
            linear2_input = self.relu(linear1_output)
            output = self.linear2(linear2_input)
            output = self.drop_layer(output)
            return output

## Training models

In [None]:
# Store training results

df = pd.DataFrame(columns=['architecture', 'embeddings',
                           'hidden', 'dropouts',
                           'learning_rate', 'epochs',
                           'dev_acc', 'dev_prec', 'dev_recall',
                           'metric'])

# Model architecture parameters
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_SIZES = [32, 64, 128, 256]
HIDDEN_SIZES = [10, 25, 40, 50]
OUTPUT_SIZE = 1
DROPOUTS = [0, 0.1, 0.25, 0.5, 0.75]
PADDING_IDX = TEXT.vocab.stoi[TEXT.pad_token]

# Model training hyperparameters
LEARNING_RATE = [0.01, 0.001, 0.0001]
train_len = 0
train_pos = 0
for batch in train_it:
    train_len += len(batch.decision_binary)
    train_pos += batch.decision_binary.sum().item()
POS_WEIGHT = torch.tensor([(train_len - train_pos) / train_pos])
if USE_CUDA:
    POS_WEIGHT = POS_WEIGHT.cuda()
EPOCHS = 10

# Iterator over various model parameters
param_iter = product (EMBEDDING_SIZES, HIDDEN_SIZES,
                      DROPOUTS, LEARNING_RATE)

# Magic loop
best_acc = (None, None)
best_rec = (None, None)
best_prec = (None, None)
for i, (embed_size, hidden_size, dropout, lr) in enumerate(param_iter):
    print(f'Architecture #{i}\n' + '-' * 20)
    model = WordEmbAvg(INPUT_DIM, embed_size, hidden_size,
                OUTPUT_SIZE, PADDING_IDX, dropout_p=dropout)
    
    tm = TrainingModule(model, lr, POS_WEIGHT, USE_CUDA, EPOCHS)
    
    best_models = tm.train_model(train_it, val_it)
    
    for metric, best_model in best_models.items():
        row = [i, embed_size, hidden_size, dropout,
               lr, EPOCHS, best_model.accuracy,
               best_model.precision, best_model.recall, metric]
        df.loc[len(df)] = row
        if best_acc[0] is None or isnan(best_acc[1]) or\
           best_model.accuracy > best_acc[1]:
            best_acc = (copy.deepcopy(best_model.model), best_model.accuracy)
        if best_rec[0] is None or isnan(best_rec[1]) or\
           best_model.recall > best_rec[1]:
            best_rec = (copy.deepcopy(best_model.model), best_model.recall)
        if best_prec[0] is None or isnan(best_prec[1]) or\
           best_model.precision > best_prec[1]:
            best_prec = (copy.deepcopy(best_model.model), best_model.precision)
    
    print('-' * 20 + '\n')


Architecture #0
--------------------
Epoch 0: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7019
Epoch 1: Dev Accuracy: 0.3417; Dev Precision: 0.0889; Dev Recall: nan; Dev Loss:1.1038
Epoch 2: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0038
Epoch 3: Dev Accuracy: 0.8000; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2269
Epoch 4: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.6306
Epoch 5: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.3888
Epoch 6: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.5370
Epoch 7: Dev Accuracy: 0.6875; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.4988
Epoch 8: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.5655
Epoch 9: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.5742
--------------------

Architecture #1
--------------------
Epoch 0: Dev Accuracy: 0.1375; Dev Precision: 0.1375; Dev Rec

Epoch 0: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7653
Epoch 1: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7287
Epoch 2: Dev Accuracy: 0.5583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7769
Epoch 3: Dev Accuracy: 0.6875; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7878
Epoch 4: Dev Accuracy: 0.6042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8137
Epoch 5: Dev Accuracy: 0.6542; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8188
Epoch 6: Dev Accuracy: 0.7583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9065
Epoch 7: Dev Accuracy: 0.6708; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8567
Epoch 8: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9015
Epoch 9: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9300
--------------------

Architecture #10
--------------------
Epoch 0: Dev Accuracy: 0.8083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7721
Epoch 1: Dev Acc

Epoch 0: Dev Accuracy: 0.1375; Dev Precision: 0.1375; Dev Recall: nan; Dev Loss:0.8966
Epoch 1: Dev Accuracy: 0.5875; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8923
Epoch 2: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8061
Epoch 3: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8789
Epoch 4: Dev Accuracy: 0.7208; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0259
Epoch 5: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1682
Epoch 6: Dev Accuracy: 0.7750; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1361
Epoch 7: Dev Accuracy: 0.7792; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0586
Epoch 8: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1117
Epoch 9: Dev Accuracy: 0.7792; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1504
--------------------

Architecture #19
--------------------
Epoch 0: Dev Accuracy: 0.6083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7874
Epoch 1: Dev 

Epoch 1: Dev Accuracy: 0.6083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7783
Epoch 2: Dev Accuracy: 0.8250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6878
Epoch 3: Dev Accuracy: 0.6583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7524
Epoch 4: Dev Accuracy: 0.7792; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7085
Epoch 5: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7304
Epoch 6: Dev Accuracy: 0.3458; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8852
Epoch 7: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6999
Epoch 8: Dev Accuracy: 0.8125; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7219
Epoch 9: Dev Accuracy: 0.5542; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8058
--------------------

Architecture #28
--------------------
Epoch 0: Dev Accuracy: 0.2750; Dev Precision: 0.1236; Dev Recall: nan; Dev Loss:0.8022
Epoch 1: Dev Accuracy: 0.3417; Dev Precision: 0.1278; Dev Recall: nan; Dev Loss:0.8028
Epoch 2: D

Epoch 2: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7576
Epoch 3: Dev Accuracy: 0.6875; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9249
Epoch 4: Dev Accuracy: 0.6958; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0183
Epoch 5: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2961
Epoch 6: Dev Accuracy: 0.6875; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2273
Epoch 7: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.6149
Epoch 8: Dev Accuracy: 0.6375; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.3409
Epoch 9: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.5092
--------------------

Architecture #37
--------------------
Epoch 0: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7836
Epoch 1: Dev Accuracy: 0.6208; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7853
Epoch 2: Dev Accuracy: 0.6375; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7750
Epoch 3: Dev Acc

Epoch 3: Dev Accuracy: 0.7750; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.3305
Epoch 4: Dev Accuracy: 0.8000; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2162
Epoch 5: Dev Accuracy: 0.7542; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1760
Epoch 6: Dev Accuracy: 0.8500; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.4245
Epoch 7: Dev Accuracy: 0.7750; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.8142
Epoch 8: Dev Accuracy: 0.6750; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.7127
Epoch 9: Dev Accuracy: 0.7958; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.3312
--------------------

Architecture #46
--------------------
Epoch 0: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7834
Epoch 1: Dev Accuracy: 0.5875; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7902
Epoch 2: Dev Accuracy: 0.6042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7742
Epoch 3: Dev Accuracy: 0.6042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7749
Epoch 4: Dev Acc

Epoch 4: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6732
Epoch 5: Dev Accuracy: 0.7917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6395
Epoch 6: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6628
Epoch 7: Dev Accuracy: 0.8292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.6919
Epoch 8: Dev Accuracy: 0.7750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7129
Epoch 9: Dev Accuracy: 0.8083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7421
--------------------

Architecture #55
--------------------
Epoch 0: Dev Accuracy: 0.2917; Dev Precision: 0.1389; Dev Recall: nan; Dev Loss:0.8047
Epoch 1: Dev Accuracy: 0.3208; Dev Precision: 0.1458; Dev Recall: nan; Dev Loss:0.8071
Epoch 2: Dev Accuracy: 0.4583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8053
Epoch 3: Dev Accuracy: 0.5042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7937
Epoch 4: Dev Accuracy: 0.4208; Dev Precision: 0.1375; Dev Recall: nan; Dev Loss:0.8270
Epoch 5

Epoch 5: Dev Accuracy: 0.7042; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2684
Epoch 6: Dev Accuracy: 0.5875; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.7393
Epoch 7: Dev Accuracy: 0.6708; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2207
Epoch 8: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1907
Epoch 9: Dev Accuracy: 0.7583; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2880
--------------------

Architecture #64
--------------------
Epoch 0: Dev Accuracy: 0.1375; Dev Precision: 0.1375; Dev Recall: nan; Dev Loss:0.8159
Epoch 1: Dev Accuracy: 0.6250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7905
Epoch 2: Dev Accuracy: 0.6083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7831
Epoch 3: Dev Accuracy: 0.6083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7834
Epoch 4: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7315
Epoch 5: Dev Accuracy: 0.6583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7827
Epoch 6: Dev 

Epoch 6: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7837
Epoch 7: Dev Accuracy: 0.6958; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8090
Epoch 8: Dev Accuracy: 0.5042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8717
Epoch 9: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7804
--------------------

Architecture #73
--------------------
Epoch 0: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7722
Epoch 1: Dev Accuracy: 0.4750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8061
Epoch 2: Dev Accuracy: 0.4750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8032
Epoch 3: Dev Accuracy: 0.4917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7993
Epoch 4: Dev Accuracy: 0.5417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7949
Epoch 5: Dev Accuracy: 0.3042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8153
Epoch 6: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7889
Epoch 7: Dev Acc

Epoch 7: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1225
Epoch 8: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.2795
Epoch 9: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.5467
--------------------

Architecture #82
--------------------
Epoch 0: Dev Accuracy: 0.2708; Dev Precision: 0.1361; Dev Recall: nan; Dev Loss:0.8161
Epoch 1: Dev Accuracy: 0.5250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8022
Epoch 2: Dev Accuracy: 0.6083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7891
Epoch 3: Dev Accuracy: 0.7083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7605
Epoch 4: Dev Accuracy: 0.6542; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7878
Epoch 5: Dev Accuracy: 0.5542; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8578
Epoch 6: Dev Accuracy: 0.6583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8585
Epoch 7: Dev Accuracy: 0.6417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8976
Epoch 8: Dev 

Epoch 8: Dev Accuracy: 0.7792; Dev Precision: nan; Dev Recall: nan; Dev Loss:2.0249
Epoch 9: Dev Accuracy: 0.7792; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.8913
--------------------

Architecture #91
--------------------
Epoch 0: Dev Accuracy: 0.3917; Dev Precision: 0.1917; Dev Recall: nan; Dev Loss:0.8221
Epoch 1: Dev Accuracy: 0.4917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8232
Epoch 2: Dev Accuracy: 0.6750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7825
Epoch 3: Dev Accuracy: 0.7083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7653
Epoch 4: Dev Accuracy: 0.7083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8329
Epoch 5: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9017
Epoch 6: Dev Accuracy: 0.6750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9656
Epoch 7: Dev Accuracy: 0.6750; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0205
Epoch 8: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0332
Epoch 9: Dev 

Epoch 9: Dev Accuracy: 0.7208; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8436
--------------------

Architecture #100
--------------------
Epoch 0: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7838
Epoch 1: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7792
Epoch 2: Dev Accuracy: 0.4042; Dev Precision: 0.1722; Dev Recall: nan; Dev Loss:0.8256
Epoch 3: Dev Accuracy: 0.5042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8269
Epoch 4: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7779
Epoch 5: Dev Accuracy: 0.7250; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7200
Epoch 6: Dev Accuracy: 0.7083; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7842
Epoch 7: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7508
Epoch 8: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7737
Epoch 9: Dev Accuracy: 0.6708; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7972
------------

Epoch 9: Dev Accuracy: 0.7583; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.4591
--------------------

Architecture #109
--------------------
Epoch 0: Dev Accuracy: 0.2042; Dev Precision: 0.1417; Dev Recall: nan; Dev Loss:0.8452
Epoch 1: Dev Accuracy: 0.4042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8177
Epoch 2: Dev Accuracy: 0.3875; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8618
Epoch 3: Dev Accuracy: 0.5208; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8529
Epoch 4: Dev Accuracy: 0.6625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8349
Epoch 5: Dev Accuracy: 0.5917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9242
Epoch 6: Dev Accuracy: 0.6625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9551
Epoch 7: Dev Accuracy: 0.6583; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0114
Epoch 8: Dev Accuracy: 0.6417; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0700
Epoch 9: Dev Accuracy: 0.6417; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.1342
------------

Epoch 0: Dev Accuracy: 0.6375; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7922
Epoch 1: Dev Accuracy: 0.5750; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7944
Epoch 2: Dev Accuracy: 0.4708; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8091
Epoch 3: Dev Accuracy: 0.3375; Dev Precision: 0.1444; Dev Recall: nan; Dev Loss:0.8170
Epoch 4: Dev Accuracy: 0.6375; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7888
Epoch 5: Dev Accuracy: 0.5375; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7958
Epoch 6: Dev Accuracy: 0.4375; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8153
Epoch 7: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7044
Epoch 8: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7556
Epoch 9: Dev Accuracy: 0.6583; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7744
--------------------

Architecture #119
--------------------
Epoch 0: Dev Accuracy: 0.1375; Dev Precision: 0.1375; Dev Recall: nan; Dev Loss:0.8108
Epoch 1: 

Epoch 1: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7732
Epoch 2: Dev Accuracy: 0.7458; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7739
Epoch 3: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7532
Epoch 4: Dev Accuracy: 0.7625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7286
Epoch 5: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7431
Epoch 6: Dev Accuracy: 0.7667; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7221
Epoch 7: Dev Accuracy: 0.7125; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7482
Epoch 8: Dev Accuracy: 0.7833; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7273
Epoch 9: Dev Accuracy: 0.7292; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7512
--------------------

Architecture #128
--------------------
Epoch 0: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7275
Epoch 1: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7331
Epoch 2: Dev Ac

Epoch 2: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7346
Epoch 3: Dev Accuracy: 0.7208; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7417
Epoch 4: Dev Accuracy: 0.7042; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8086
Epoch 5: Dev Accuracy: 0.7417; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8909
Epoch 6: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.8788
Epoch 7: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9932
Epoch 8: Dev Accuracy: 0.7083; Dev Precision: nan; Dev Recall: nan; Dev Loss:1.0348
Epoch 9: Dev Accuracy: 0.6917; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.9982
--------------------

Architecture #137
--------------------
Epoch 0: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7596
Epoch 1: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7619
Epoch 2: Dev Accuracy: 0.8625; Dev Precision: nan; Dev Recall: nan; Dev Loss:0.7645
Epoch 3: Dev Ac

## Save results of model training

In [None]:
SAVE_PREFIX = '../results/SimpleNNCustom_'
df.to_csv(f'{SAVE_PREFIX}models.csv')
torch.save(best_acc[0], f'{SAVE_PREFIX}best_acc.pt')
torch.save(best_rec[0], f'{SAVE_PREFIX}best_rec.pt')
torch.save(best_prec[0], f'{SAVE_PREFIX}best_prec.pt')