# Simple NNs with custom embeddings

## Importing data

In [1]:
import copy
from itertools import product
from numpy import isnan
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.data as data
import sys
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../data_pipeline/')
import preprocessing as pre
from training import TrainingModule

SEED = 1312
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [2]:
train_data, test_data, val_data, TEXT, LABEL = pre.get_data(
    'train_small.csv', 'val_small.csv', 'test_small.csv')

Connected!


In [3]:
USE_CUDA = torch.cuda.is_available()

TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)

BATCH_SIZE = 5

train_it, test_it, val_it = data.BucketIterator.splits(
    (train_data, test_data, val_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.alj_text),
    sort_within_batch=True,
    device = torch.device('cuda' if USE_CUDA else 'cpu'))

## Simple NN Model

In [4]:
class WordEmbAvg(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, pad_idx, two_layers=True, dropout_p=0.0):
        
        super().__init__()
        
        # Define an embedding layer, a couple of linear layers, and 
        # the ReLU non-linearity.

        ##YOUR CODE HERE##
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        if two_layers == True:
            self.linear1 = nn.Linear(embedding_dim, hidden_dim)
            self.linear2 = nn.Linear(hidden_dim, output_dim) 
        else:
            self.linear1 = nn.Linear(embedding_dim, output_dim)
            self.linear2 = None
        self.relu = nn.ReLU()
        self.drop_layer = nn.Dropout(p=dropout_p)

        
        
    def forward(self, text):

        ##YOUR CODE HERE##
        embedded = self.embedding(text)
        embedded = embedded.mean(0)
        if not self.linear2:
            linear1_output = self.linear1(embedded)
            output = self.relu(linear1_output)
            output = self.drop_layer(output)
            return output
        else:
            linear1_output = self.linear1(embedded)
            linear2_input = self.relu(linear1_output)
            output = self.linear2(linear2_input)
            output = self.drop_layer(output)
            return output

## Training models

In [5]:
# Store training results

df = pd.DataFrame(columns=['architecture', 'embeddings',
                           'hidden', 'dropouts',
                           'learning_rate', 'epochs',
                           'dev_acc', 'dev_prec', 'dev_recall',
                           'metric'])

# Model architecture parameters
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_SIZES = [32, 64, 128, 256]
HIDDEN_SIZES = [10, 25, 40, 50]
OUTPUT_SIZE = 1
DROPOUTS = [0, 0.1, 0.25, 0.5, 0.75]
PADDING_IDX = TEXT.vocab.stoi[TEXT.pad_token]

# Model training hyperparameters
LEARNING_RATE = [0.01, 0.001, 0.0001]
train_len = 0
train_pos = 0
for batch in train_it:
    train_len += len(batch.decision_binary)
    train_pos += batch.decision_binary.sum().item()
POS_WEIGHT = torch.tensor([(train_len - train_pos) / train_pos])
if USE_CUDA:
    POS_WEIGHT = POS_WEIGHT.cuda()
EPOCHS = 10

# Iterator over various model parameters
param_iter = product (EMBEDDING_SIZES, HIDDEN_SIZES,
                      DROPOUTS, LEARNING_RATE)

# Magic loop
best_acc = (None, None)
best_rec = (None, None)
best_prec = (None, None)
for i, (embed_size, hidden_size, dropout, lr) in enumerate(param_iter):
    print(f'Architecture #{i}\n' + '-' * 20)
    model = WordEmbAvg(INPUT_DIM, embed_size, hidden_size,
                OUTPUT_SIZE, PADDING_IDX, dropout_p=dropout)
    
    tm = TrainingModule(model, lr, POS_WEIGHT, USE_CUDA, EPOCHS)
    
    best_models = tm.train_model(train_it, val_it)
    
    for metric, best_model in best_models.items():
        row = [i, embed_size, hidden_size, dropout,
               lr, EPOCHS, best_model.accuracy,
               best_model.precision, best_model.recall, metric]
        df.loc[len(df)] = row
        if best_acc[0] is None or isnan(best_acc[1]) or\
           best_model.accuracy > best_acc[1]:
            best_acc = (copy.deepcopy(best_model.model), best_model.accuracy)
        if best_rec[0] is None or isnan(best_rec[1]) or\
           best_model.recall > best_rec[1]:
            best_rec = (copy.deepcopy(best_model.model), best_model.recall)
        if best_prec[0] is None or isnan(best_prec[1]) or\
           best_model.precision > best_prec[1]:
            best_prec = (copy.deepcopy(best_model.model), best_model.precision)
    
    print('-' * 20 + '\n')


Architecture #0
--------------------
Epoch 0: Dev Accuracy: 0.5763; Dev Precision: 0.1600; Dev Recall: 0.5000; Dev Loss:0.1593
Epoch 1: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1496
Epoch 2: Dev Accuracy: 0.6271; Dev Precision: 0.1500; Dev Recall: 0.3750; Dev Loss:0.2529
Epoch 3: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.2502
Epoch 4: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.3061
Epoch 5: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.3075
Epoch 6: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.3691
Epoch 7: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.4013
Epoch 8: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.3366
Epoch 9: Dev Accuracy: 0.7288; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.3904
--------------------

Architecture #1
--------------------
Epoc

Epoch 6: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1765
Epoch 7: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1761
Epoch 8: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1755
Epoch 9: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1746
--------------------

Architecture #9
--------------------
Epoch 0: Dev Accuracy: 0.4576; Dev Precision: 0.1667; Dev Recall: 0.7500; Dev Loss:0.1637
Epoch 1: Dev Accuracy: 0.7119; Dev Precision: 0.2632; Dev Recall: 0.6250; Dev Loss:0.1530
Epoch 2: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.1609
Epoch 3: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.1587
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.1565
Epoch 5: Dev Accuracy: 0.6441; Dev Precision: 0.1905; Dev Recall: 0.5000; Dev Loss:0.1680
Epoch 6: Dev Accuracy: 0.7119; Dev Precis

Epoch 2: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1638
Epoch 3: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1636
Epoch 4: Dev Accuracy: 0.2034; Dev Precision: 0.1455; Dev Recall: 1.0000; Dev Loss:0.1634
Epoch 5: Dev Accuracy: 0.2712; Dev Precision: 0.1569; Dev Recall: 1.0000; Dev Loss:0.1629
Epoch 6: Dev Accuracy: 0.2881; Dev Precision: 0.1600; Dev Recall: 1.0000; Dev Loss:0.1628
Epoch 7: Dev Accuracy: 0.3051; Dev Precision: 0.1633; Dev Recall: 1.0000; Dev Loss:0.1626
Epoch 8: Dev Accuracy: 0.3559; Dev Precision: 0.1591; Dev Recall: 0.8750; Dev Loss:0.1623
Epoch 9: Dev Accuracy: 0.3729; Dev Precision: 0.1628; Dev Recall: 0.8750; Dev Loss:0.1623
--------------------

Architecture #18
--------------------
Epoch 0: Dev Accuracy: 0.6780; Dev Precision: 0.2105; Dev Recall: 0.5000; Dev Loss:0.1630
Epoch 1: Dev Accuracy: 0.5763; Dev Precision: 0.1304; Dev Recall: 0.3750; Dev Loss:0.1856
Epoch 2: Dev Accuracy: 0.5932; Dev Preci

Epoch 8: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1463
Epoch 9: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1418
--------------------

Architecture #26
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1228; Dev Recall: 0.8750; Dev Loss:0.1663
Epoch 1: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1661
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1660
Epoch 3: Dev Accuracy: 0.1186; Dev Precision: 0.1207; Dev Recall: 0.8750; Dev Loss:0.1656
Epoch 4: Dev Accuracy: 0.1186; Dev Precision: 0.1207; Dev Recall: 0.8750; Dev Loss:0.1652
Epoch 5: Dev Accuracy: 0.1186; Dev Precision: 0.1207; Dev Recall: 0.8750; Dev Loss:0.1647
Epoch 6: Dev Accuracy: 0.1525; Dev Precision: 0.1250; Dev Recall: 0.8750; Dev Loss:0.1641
Epoch 7: Dev Accuracy: 0.1186; Dev Precision: 0.1207; Dev Recall: 0.8750; Dev Loss:0.1646
Epoch 8: Dev Accuracy: 0.1695; Dev Preci

Epoch 4: Dev Accuracy: 0.6441; Dev Precision: 0.2174; Dev Recall: 0.6250; Dev Loss:0.1633
Epoch 5: Dev Accuracy: 0.6949; Dev Precision: 0.2500; Dev Recall: 0.6250; Dev Loss:0.1684
Epoch 6: Dev Accuracy: 0.6780; Dev Precision: 0.2381; Dev Recall: 0.6250; Dev Loss:0.1800
Epoch 7: Dev Accuracy: 0.6780; Dev Precision: 0.1765; Dev Recall: 0.3750; Dev Loss:0.1880
Epoch 8: Dev Accuracy: 0.6780; Dev Precision: 0.1333; Dev Recall: 0.2500; Dev Loss:0.1983
Epoch 9: Dev Accuracy: 0.6780; Dev Precision: 0.1333; Dev Recall: 0.2500; Dev Loss:0.2060
--------------------

Architecture #35
--------------------
Epoch 0: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1589
Epoch 1: Dev Accuracy: 0.7627; Dev Precision: 0.1250; Dev Recall: 0.1250; Dev Loss:0.1595
Epoch 2: Dev Accuracy: 0.7627; Dev Precision: 0.1250; Dev Recall: 0.1250; Dev Loss:0.1598
Epoch 3: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1602
Epoch 4: Dev Accuracy: 0.7458; Dev Preci

Epoch 0: Dev Accuracy: 0.7797; Dev Precision: 0.1429; Dev Recall: 0.1250; Dev Loss:0.1602
Epoch 1: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.1609
Epoch 2: Dev Accuracy: 0.2881; Dev Precision: 0.1458; Dev Recall: 0.8750; Dev Loss:0.1666
Epoch 3: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1576
Epoch 4: Dev Accuracy: 0.5254; Dev Precision: 0.1154; Dev Recall: 0.3750; Dev Loss:0.1626
Epoch 5: Dev Accuracy: 0.7119; Dev Precision: 0.1538; Dev Recall: 0.2500; Dev Loss:0.1569
Epoch 6: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1567
Epoch 7: Dev Accuracy: 0.5593; Dev Precision: 0.0909; Dev Recall: 0.2500; Dev Loss:0.1608
Epoch 8: Dev Accuracy: 0.6102; Dev Precision: 0.1429; Dev Recall: 0.3750; Dev Loss:0.1596
Epoch 9: Dev Accuracy: 0.6271; Dev Precision: 0.1111; Dev Recall: 0.2500; Dev Loss:0.1570
--------------------

Architecture #44
--------------------
Epoch 0: Dev Accuracy: 0.8644; Dev Preci

Epoch 6: Dev Accuracy: 0.6102; Dev Precision: 0.2000; Dev Recall: 0.6250; Dev Loss:0.3103
Epoch 7: Dev Accuracy: 0.8305; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.3290
Epoch 8: Dev Accuracy: 0.8136; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.2582
Epoch 9: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.2298
--------------------

Architecture #52
--------------------
Epoch 0: Dev Accuracy: 0.5424; Dev Precision: 0.1935; Dev Recall: 0.7500; Dev Loss:0.1617
Epoch 1: Dev Accuracy: 0.5085; Dev Precision: 0.2000; Dev Recall: 0.8750; Dev Loss:0.1635
Epoch 2: Dev Accuracy: 0.7288; Dev Precision: 0.2778; Dev Recall: 0.6250; Dev Loss:0.1552
Epoch 3: Dev Accuracy: 0.6949; Dev Precision: 0.2500; Dev Recall: 0.6250; Dev Loss:0.1554
Epoch 4: Dev Accuracy: 0.6271; Dev Precision: 0.2083; Dev Recall: 0.6250; Dev Loss:0.1615
Epoch 5: Dev Accuracy: 0.6271; Dev Precision: 0.2083; Dev Recall: 0.6250; Dev Loss:0.1681
Epoch 6: Dev Accuracy: 0.6441; Dev Preci

Epoch 2: Dev Accuracy: 0.6441; Dev Precision: 0.1579; Dev Recall: 0.3750; Dev Loss:0.2115
Epoch 3: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1977
Epoch 4: Dev Accuracy: 0.7966; Dev Precision: 0.2500; Dev Recall: 0.2500; Dev Loss:0.2507
Epoch 5: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.2810
Epoch 6: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.2906
Epoch 7: Dev Accuracy: 0.7119; Dev Precision: 0.1538; Dev Recall: 0.2500; Dev Loss:0.3079
Epoch 8: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.3068
Epoch 9: Dev Accuracy: 0.6780; Dev Precision: 0.1333; Dev Recall: 0.2500; Dev Loss:0.4056
--------------------

Architecture #61
--------------------
Epoch 0: Dev Accuracy: 0.2203; Dev Precision: 0.1346; Dev Recall: 0.8750; Dev Loss:0.1649
Epoch 1: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1546
Epoch 2: Dev Accuracy: 0.7627; Dev Preci

Epoch 9: Dev Accuracy: 0.5763; Dev Precision: 0.1600; Dev Recall: 0.5000; Dev Loss:0.1612
--------------------

Architecture #69
--------------------
Epoch 0: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1600
Epoch 1: Dev Accuracy: 0.6271; Dev Precision: 0.1500; Dev Recall: 0.3750; Dev Loss:0.1635
Epoch 2: Dev Accuracy: 0.6102; Dev Precision: 0.1739; Dev Recall: 0.5000; Dev Loss:0.1621
Epoch 3: Dev Accuracy: 0.4407; Dev Precision: 0.0968; Dev Recall: 0.3750; Dev Loss:0.1911
Epoch 4: Dev Accuracy: 0.6271; Dev Precision: 0.0625; Dev Recall: 0.1250; Dev Loss:0.1786
Epoch 5: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.1646
Epoch 6: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1498
Epoch 7: Dev Accuracy: 0.7458; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1492
Epoch 8: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1409
Epoch 9: Dev Accuracy: 0.7458; Dev Preci

Epoch 5: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1585
Epoch 6: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1586
Epoch 7: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1585
Epoch 8: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1585
Epoch 9: Dev Accuracy: 0.8475; Dev Precision: 0.3333; Dev Recall: 0.1250; Dev Loss:0.1584
--------------------

Architecture #78
--------------------
Epoch 0: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1494
Epoch 1: Dev Accuracy: 0.7797; Dev Precision: 0.2727; Dev Recall: 0.3750; Dev Loss:0.1394
Epoch 2: Dev Accuracy: 0.6780; Dev Precision: 0.1765; Dev Recall: 0.3750; Dev Loss:0.1658
Epoch 3: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1975
Epoch 4: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.2414
Epoch 5: Dev Accuracy: 0.8305; Dev Preci

Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1664
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1660
Epoch 3: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1653
Epoch 4: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1648
Epoch 5: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1646
Epoch 6: Dev Accuracy: 0.2542; Dev Precision: 0.1538; Dev Recall: 1.0000; Dev Loss:0.1634
Epoch 7: Dev Accuracy: 0.3220; Dev Precision: 0.1667; Dev Recall: 1.0000; Dev Loss:0.1631
Epoch 8: Dev Accuracy: 0.3220; Dev Precision: 0.1667; Dev Recall: 1.0000; Dev Loss:0.1633
Epoch 9: Dev Accuracy: 0.3220; Dev Precision: 0.1667; Dev Recall: 1.0000; Dev Loss:0.1633
--------------------

Architecture #87
--------------------
Epoch 0: Dev Accuracy: 0.4237; Dev Precision: 0.0667; Dev Recall: 0.2500; Dev Loss:0.1623
Epoch 1: Dev Accuracy: 0.1695; Dev Preci

Epoch 7: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.2066
Epoch 8: Dev Accuracy: 0.5593; Dev Precision: 0.1538; Dev Recall: 0.5000; Dev Loss:0.2680
Epoch 9: Dev Accuracy: 0.7627; Dev Precision: 0.2500; Dev Recall: 0.3750; Dev Loss:0.2113
--------------------

Architecture #95
--------------------
Epoch 0: Dev Accuracy: 0.3729; Dev Precision: 0.1081; Dev Recall: 0.5000; Dev Loss:0.1633
Epoch 1: Dev Accuracy: 0.3898; Dev Precision: 0.1111; Dev Recall: 0.5000; Dev Loss:0.1625
Epoch 2: Dev Accuracy: 0.4407; Dev Precision: 0.1429; Dev Recall: 0.6250; Dev Loss:0.1624
Epoch 3: Dev Accuracy: 0.4915; Dev Precision: 0.1562; Dev Recall: 0.6250; Dev Loss:0.1621
Epoch 4: Dev Accuracy: 0.4915; Dev Precision: 0.1562; Dev Recall: 0.6250; Dev Loss:0.1625
Epoch 5: Dev Accuracy: 0.5254; Dev Precision: 0.1667; Dev Recall: 0.6250; Dev Loss:0.1623
Epoch 6: Dev Accuracy: 0.4915; Dev Precision: 0.1333; Dev Recall: 0.5000; Dev Loss:0.1622
Epoch 7: Dev Accuracy: 0.6102; Dev Preci

Epoch 2: Dev Accuracy: 0.3390; Dev Precision: 0.1556; Dev Recall: 0.8750; Dev Loss:0.1651
Epoch 3: Dev Accuracy: 0.6949; Dev Precision: 0.2500; Dev Recall: 0.6250; Dev Loss:0.1589
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1543
Epoch 5: Dev Accuracy: 0.6610; Dev Precision: 0.2500; Dev Recall: 0.7500; Dev Loss:0.1601
Epoch 6: Dev Accuracy: 0.3051; Dev Precision: 0.1489; Dev Recall: 0.8750; Dev Loss:0.1701
Epoch 7: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.1585
Epoch 8: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.1556
Epoch 9: Dev Accuracy: 0.4576; Dev Precision: 0.1842; Dev Recall: 0.8750; Dev Loss:0.1716
--------------------

Architecture #104
--------------------
Epoch 0: Dev Accuracy: 0.1864; Dev Precision: 0.1296; Dev Recall: 0.8750; Dev Loss:0.1656
Epoch 1: Dev Accuracy: 0.2203; Dev Precision: 0.1346; Dev Recall: 0.8750; Dev Loss:0.1648
Epoch 2: Dev Accuracy: 0.2373; Dev Prec

Epoch 8: Dev Accuracy: 0.7797; Dev Precision: 0.2222; Dev Recall: 0.2500; Dev Loss:0.3124
Epoch 9: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.2652
--------------------

Architecture #112
--------------------
Epoch 0: Dev Accuracy: 0.6610; Dev Precision: 0.1250; Dev Recall: 0.2500; Dev Loss:0.1569
Epoch 1: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.1555
Epoch 2: Dev Accuracy: 0.6949; Dev Precision: 0.1429; Dev Recall: 0.2500; Dev Loss:0.1524
Epoch 3: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1483
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1607
Epoch 5: Dev Accuracy: 0.7119; Dev Precision: 0.2000; Dev Recall: 0.3750; Dev Loss:0.1645
Epoch 6: Dev Accuracy: 0.5424; Dev Precision: 0.1481; Dev Recall: 0.5000; Dev Loss:0.1929
Epoch 7: Dev Accuracy: 0.6949; Dev Precision: 0.1429; Dev Recall: 0.2500; Dev Loss:0.1776
Epoch 8: Dev Accuracy: 0.7288; Dev Prec

Epoch 3: Dev Accuracy: 0.6610; Dev Precision: 0.2273; Dev Recall: 0.6250; Dev Loss:0.3043
Epoch 4: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.3134
Epoch 5: Dev Accuracy: 0.7119; Dev Precision: 0.2000; Dev Recall: 0.3750; Dev Loss:0.3686
Epoch 6: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.4000
Epoch 7: Dev Accuracy: 0.7288; Dev Precision: 0.2143; Dev Recall: 0.3750; Dev Loss:0.4058
Epoch 8: Dev Accuracy: 0.7288; Dev Precision: 0.2143; Dev Recall: 0.3750; Dev Loss:0.4560
Epoch 9: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.4534
--------------------

Architecture #121
--------------------
Epoch 0: Dev Accuracy: 0.6102; Dev Precision: 0.1429; Dev Recall: 0.3750; Dev Loss:0.1612
Epoch 1: Dev Accuracy: 0.6780; Dev Precision: 0.2105; Dev Recall: 0.5000; Dev Loss:0.1574
Epoch 2: Dev Accuracy: 0.5593; Dev Precision: 0.1538; Dev Recall: 0.5000; Dev Loss:0.1626
Epoch 3: Dev Accuracy: 0.5932; Dev Prec

Epoch 9: Dev Accuracy: 0.7797; Dev Precision: 0.2222; Dev Recall: 0.2500; Dev Loss:0.1558
--------------------

Architecture #129
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1689
Epoch 1: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.1535
Epoch 2: Dev Accuracy: 0.6780; Dev Precision: 0.2105; Dev Recall: 0.5000; Dev Loss:0.1786
Epoch 3: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.1426
Epoch 4: Dev Accuracy: 0.5593; Dev Precision: 0.1538; Dev Recall: 0.5000; Dev Loss:0.2320
Epoch 5: Dev Accuracy: 0.8136; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.2265
Epoch 6: Dev Accuracy: 0.7119; Dev Precision: 0.2000; Dev Recall: 0.3750; Dev Loss:0.1590
Epoch 7: Dev Accuracy: 0.6780; Dev Precision: 0.1333; Dev Recall: 0.2500; Dev Loss:0.1565
Epoch 8: Dev Accuracy: 0.7966; Dev Precision: 0.1667; Dev Recall: 0.1250; Dev Loss:0.1572
Epoch 9: Dev Accuracy: 0.8305; Dev Prec

Epoch 5: Dev Accuracy: 0.8475; Dev Precision: 0.3333; Dev Recall: 0.1250; Dev Loss:0.1568
Epoch 6: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1568
Epoch 7: Dev Accuracy: 0.8136; Dev Precision: 0.2857; Dev Recall: 0.2500; Dev Loss:0.1573
Epoch 8: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1574
Epoch 9: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.1580
--------------------

Architecture #138
--------------------
Epoch 0: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1581
Epoch 1: Dev Accuracy: 0.5763; Dev Precision: 0.2069; Dev Recall: 0.7500; Dev Loss:0.1837
Epoch 2: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.2087
Epoch 3: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.2128
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.2261
Epoch 5: Dev Accuracy: 0.6610; Dev Precisi

Epoch 0: Dev Accuracy: 0.8644; Dev Precision: 0.5000; Dev Recall: 0.1250; Dev Loss:0.1586
Epoch 1: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1590
Epoch 2: Dev Accuracy: 0.8136; Dev Precision: 0.2000; Dev Recall: 0.1250; Dev Loss:0.1592
Epoch 3: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1587
Epoch 4: Dev Accuracy: 0.7797; Dev Precision: 0.1429; Dev Recall: 0.1250; Dev Loss:0.1589
Epoch 5: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1595
Epoch 6: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1598
Epoch 7: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.1600
Epoch 8: Dev Accuracy: 0.6102; Dev Precision: 0.1429; Dev Recall: 0.3750; Dev Loss:0.1609
Epoch 9: Dev Accuracy: 0.4746; Dev Precision: 0.1290; Dev Recall: 0.5000; Dev Loss:0.1626
--------------------

Architecture #147
--------------------
Epoch 0: Dev Accuracy: 0.1525; Dev Prec

Epoch 6: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1715
Epoch 7: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1767
Epoch 8: Dev Accuracy: 0.7966; Dev Precision: 0.2500; Dev Recall: 0.2500; Dev Loss:0.1870
Epoch 9: Dev Accuracy: 0.7797; Dev Precision: 0.1429; Dev Recall: 0.1250; Dev Loss:0.2720
--------------------

Architecture #155
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1680
Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1677
Epoch 2: Dev Accuracy: 0.1186; Dev Precision: 0.1207; Dev Recall: 0.8750; Dev Loss:0.1664
Epoch 3: Dev Accuracy: 0.1864; Dev Precision: 0.1296; Dev Recall: 0.8750; Dev Loss:0.1655
Epoch 4: Dev Accuracy: 0.2881; Dev Precision: 0.1458; Dev Recall: 0.8750; Dev Loss:0.1646
Epoch 5: Dev Accuracy: 0.4407; Dev Precision: 0.1795; Dev Recall: 0.8750; Dev Loss:0.1639
Epoch 6: Dev Accuracy: 0.6102; Dev Prec

Epoch 1: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1577
Epoch 2: Dev Accuracy: 0.6102; Dev Precision: 0.2000; Dev Recall: 0.6250; Dev Loss:0.1614
Epoch 3: Dev Accuracy: 0.3898; Dev Precision: 0.1818; Dev Recall: 1.0000; Dev Loss:0.1727
Epoch 4: Dev Accuracy: 0.5254; Dev Precision: 0.1429; Dev Recall: 0.5000; Dev Loss:0.1633
Epoch 5: Dev Accuracy: 0.3390; Dev Precision: 0.1556; Dev Recall: 0.8750; Dev Loss:0.1751
Epoch 6: Dev Accuracy: 0.6271; Dev Precision: 0.1500; Dev Recall: 0.3750; Dev Loss:0.1607
Epoch 7: Dev Accuracy: 0.7119; Dev Precision: 0.2353; Dev Recall: 0.5000; Dev Loss:0.1532
Epoch 8: Dev Accuracy: 0.7288; Dev Precision: 0.2143; Dev Recall: 0.3750; Dev Loss:0.1510
Epoch 9: Dev Accuracy: 0.6271; Dev Precision: 0.1818; Dev Recall: 0.5000; Dev Loss:0.1593
--------------------

Architecture #164
--------------------
Epoch 0: Dev Accuracy: 0.2203; Dev Precision: 0.1200; Dev Recall: 0.7500; Dev Loss:0.1650
Epoch 1: Dev Accuracy: 0.2712; Dev Prec

Epoch 7: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.2626
Epoch 8: Dev Accuracy: 0.7797; Dev Precision: 0.1429; Dev Recall: 0.1250; Dev Loss:0.2729
Epoch 9: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.2588
--------------------

Architecture #172
--------------------
Epoch 0: Dev Accuracy: 0.7797; Dev Precision: 0.2727; Dev Recall: 0.3750; Dev Loss:0.1592
Epoch 1: Dev Accuracy: 0.2373; Dev Precision: 0.1373; Dev Recall: 0.8750; Dev Loss:0.1804
Epoch 2: Dev Accuracy: 0.7458; Dev Precision: 0.2941; Dev Recall: 0.6250; Dev Loss:0.1549
Epoch 3: Dev Accuracy: 0.6271; Dev Precision: 0.1818; Dev Recall: 0.5000; Dev Loss:0.1657
Epoch 4: Dev Accuracy: 0.6610; Dev Precision: 0.2273; Dev Recall: 0.6250; Dev Loss:0.1667
Epoch 5: Dev Accuracy: 0.6102; Dev Precision: 0.2222; Dev Recall: 0.7500; Dev Loss:0.1845
Epoch 6: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.1662
Epoch 7: Dev Accuracy: 0.5254; Dev Prec

Epoch 2: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1775
Epoch 3: Dev Accuracy: 0.7966; Dev Precision: 0.2500; Dev Recall: 0.2500; Dev Loss:0.2101
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.2308; Dev Recall: 0.3750; Dev Loss:0.2428
Epoch 5: Dev Accuracy: 0.6949; Dev Precision: 0.1875; Dev Recall: 0.3750; Dev Loss:0.3513
Epoch 6: Dev Accuracy: 0.8475; Dev Precision: 0.3333; Dev Recall: 0.1250; Dev Loss:0.3551
Epoch 7: Dev Accuracy: 0.8305; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.3202
Epoch 8: Dev Accuracy: 0.6102; Dev Precision: 0.1429; Dev Recall: 0.3750; Dev Loss:0.3955
Epoch 9: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.3879
--------------------

Architecture #181
--------------------
Epoch 0: Dev Accuracy: 0.6271; Dev Precision: 0.2308; Dev Recall: 0.7500; Dev Loss:0.1614
Epoch 1: Dev Accuracy: 0.6610; Dev Precision: 0.2500; Dev Recall: 0.7500; Dev Loss:0.1565
Epoch 2: Dev Accuracy: 0.7288; Dev Prec

Epoch 8: Dev Accuracy: 0.3051; Dev Precision: 0.1333; Dev Recall: 0.7500; Dev Loss:0.1654
Epoch 9: Dev Accuracy: 0.4915; Dev Precision: 0.1765; Dev Recall: 0.7500; Dev Loss:0.1636
--------------------

Architecture #189
--------------------
Epoch 0: Dev Accuracy: 0.7627; Dev Precision: 0.2000; Dev Recall: 0.2500; Dev Loss:0.1613
Epoch 1: Dev Accuracy: 0.6780; Dev Precision: 0.2105; Dev Recall: 0.5000; Dev Loss:0.1582
Epoch 2: Dev Accuracy: 0.6441; Dev Precision: 0.1579; Dev Recall: 0.3750; Dev Loss:0.1636
Epoch 3: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.1487
Epoch 4: Dev Accuracy: 0.6441; Dev Precision: 0.0667; Dev Recall: 0.1250; Dev Loss:0.1763
Epoch 5: Dev Accuracy: 0.5254; Dev Precision: 0.1154; Dev Recall: 0.3750; Dev Loss:0.2063
Epoch 6: Dev Accuracy: 0.8136; Dev Precision: 0.2000; Dev Recall: 0.1250; Dev Loss:0.1545
Epoch 7: Dev Accuracy: 0.7627; Dev Precision: 0.2500; Dev Recall: 0.3750; Dev Loss:0.1502
Epoch 8: Dev Accuracy: 0.8305; Dev Prec

Epoch 3: Dev Accuracy: 0.6271; Dev Precision: 0.1111; Dev Recall: 0.2500; Dev Loss:0.1581
Epoch 4: Dev Accuracy: 0.6441; Dev Precision: 0.1579; Dev Recall: 0.3750; Dev Loss:0.1581
Epoch 5: Dev Accuracy: 0.6441; Dev Precision: 0.1176; Dev Recall: 0.2500; Dev Loss:0.1575
Epoch 6: Dev Accuracy: 0.6271; Dev Precision: 0.1111; Dev Recall: 0.2500; Dev Loss:0.1572
Epoch 7: Dev Accuracy: 0.6102; Dev Precision: 0.1429; Dev Recall: 0.3750; Dev Loss:0.1579
Epoch 8: Dev Accuracy: 0.6271; Dev Precision: 0.1818; Dev Recall: 0.5000; Dev Loss:0.1573
Epoch 9: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.1543
--------------------

Architecture #198
--------------------
Epoch 0: Dev Accuracy: 0.4068; Dev Precision: 0.1351; Dev Recall: 0.6250; Dev Loss:0.1857
Epoch 1: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.1580
Epoch 2: Dev Accuracy: 0.6441; Dev Precision: 0.1905; Dev Recall: 0.5000; Dev Loss:0.2632
Epoch 3: Dev Accuracy: 0.7627; Dev Prec

Epoch 9: Dev Accuracy: 0.6441; Dev Precision: 0.1579; Dev Recall: 0.3750; Dev Loss:0.1567
--------------------

Architecture #206
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1738
Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1707
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1228; Dev Recall: 0.8750; Dev Loss:0.1658
Epoch 3: Dev Accuracy: 0.4407; Dev Precision: 0.1429; Dev Recall: 0.6250; Dev Loss:0.1633
Epoch 4: Dev Accuracy: 0.5424; Dev Precision: 0.1724; Dev Recall: 0.6250; Dev Loss:0.1618
Epoch 5: Dev Accuracy: 0.6271; Dev Precision: 0.2083; Dev Recall: 0.6250; Dev Loss:0.1617
Epoch 6: Dev Accuracy: 0.5085; Dev Precision: 0.1613; Dev Recall: 0.6250; Dev Loss:0.1619
Epoch 7: Dev Accuracy: 0.4915; Dev Precision: 0.1562; Dev Recall: 0.6250; Dev Loss:0.1624
Epoch 8: Dev Accuracy: 0.6441; Dev Precision: 0.1905; Dev Recall: 0.5000; Dev Loss:0.1599
Epoch 9: Dev Accuracy: 0.5085; Dev Prec

Epoch 4: Dev Accuracy: 0.7288; Dev Precision: 0.2143; Dev Recall: 0.3750; Dev Loss:0.1955
Epoch 5: Dev Accuracy: 0.6949; Dev Precision: 0.1429; Dev Recall: 0.2500; Dev Loss:0.2319
Epoch 6: Dev Accuracy: 0.6949; Dev Precision: 0.1429; Dev Recall: 0.2500; Dev Loss:0.2435
Epoch 7: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.2209
Epoch 8: Dev Accuracy: 0.6949; Dev Precision: 0.1429; Dev Recall: 0.2500; Dev Loss:0.2467
Epoch 9: Dev Accuracy: 0.7288; Dev Precision: 0.1667; Dev Recall: 0.2500; Dev Loss:0.2933
--------------------

Architecture #215
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1641
Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1644
Epoch 2: Dev Accuracy: 0.4068; Dev Precision: 0.1860; Dev Recall: 1.0000; Dev Loss:0.1631
Epoch 3: Dev Accuracy: 0.5085; Dev Precision: 0.2162; Dev Recall: 1.0000; Dev Loss:0.1626
Epoch 4: Dev Accuracy: 0.4915; Dev Prec

Epoch 0: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1575
Epoch 1: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1552
Epoch 2: Dev Accuracy: 0.4407; Dev Precision: 0.1795; Dev Recall: 0.8750; Dev Loss:0.1684
Epoch 3: Dev Accuracy: 0.3051; Dev Precision: 0.1489; Dev Recall: 0.8750; Dev Loss:0.1699
Epoch 4: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1493
Epoch 5: Dev Accuracy: 0.2712; Dev Precision: 0.1429; Dev Recall: 0.8750; Dev Loss:0.1722
Epoch 6: Dev Accuracy: 0.4746; Dev Precision: 0.1714; Dev Recall: 0.7500; Dev Loss:0.1711
Epoch 7: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1544
Epoch 8: Dev Accuracy: 0.6780; Dev Precision: 0.1333; Dev Recall: 0.2500; Dev Loss:0.1568
Epoch 9: Dev Accuracy: 0.5763; Dev Precision: 0.1304; Dev Recall: 0.3750; Dev Loss:0.1565
--------------------

Architecture #224
--------------------
Epoch 0: Dev Accuracy: 0.8644; Dev Prec

Epoch 6: Dev Accuracy: 0.8136; Dev Precision: 0.2000; Dev Recall: 0.1250; Dev Loss:0.2111
Epoch 7: Dev Accuracy: 0.8475; Dev Precision: 0.3333; Dev Recall: 0.1250; Dev Loss:0.2350
Epoch 8: Dev Accuracy: 0.6610; Dev Precision: 0.1667; Dev Recall: 0.3750; Dev Loss:0.2677
Epoch 9: Dev Accuracy: 0.8136; Dev Precision: 0.2857; Dev Recall: 0.2500; Dev Loss:0.2634
--------------------

Architecture #232
--------------------
Epoch 0: Dev Accuracy: 0.7119; Dev Precision: 0.2353; Dev Recall: 0.5000; Dev Loss:0.1554
Epoch 1: Dev Accuracy: 0.6102; Dev Precision: 0.2222; Dev Recall: 0.7500; Dev Loss:0.1592
Epoch 2: Dev Accuracy: 0.4237; Dev Precision: 0.1579; Dev Recall: 0.7500; Dev Loss:0.1812
Epoch 3: Dev Accuracy: 0.7458; Dev Precision: 0.1818; Dev Recall: 0.2500; Dev Loss:0.1468
Epoch 4: Dev Accuracy: 0.7119; Dev Precision: 0.2353; Dev Recall: 0.5000; Dev Loss:0.1569
Epoch 5: Dev Accuracy: 0.7797; Dev Precision: 0.2727; Dev Recall: 0.3750; Dev Loss:0.1596
Epoch 6: Dev Accuracy: 0.7288; Dev Prec

## Save results of model training

In [6]:
SAVE_PREFIX = '../results/SimpleNNCustom_'
df.to_csv(f'{SAVE_PREFIX}models.csv')
torch.save(best_acc[0], f'{SAVE_PREFIX}best_acc.pt')
torch.save(best_rec[0], f'{SAVE_PREFIX}best_rec.pt')
torch.save(best_prec[0], f'{SAVE_PREFIX}best_prec.pt')