In [1]:
!pip install wget
!pip install textblob
!pip install pyspellchecker
!pip install nltk
!pip install tqdm

Collecting wget
  Downloading wget-3.2.zip (10 kB)
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9675 sha256=ac4afc512847715e94ef1df29a997b7afc70e479886ac36a1fc5a25aa445d7b1
  Stored in directory: /root/.cache/pip/wheels/a1/b6/7c/0e63e34eb06634181c63adacca38b79ff8f35c37e3c13e3c02
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
Collecting pyspellchecker
  Downloading pyspellchecker-0.6.3-py3-none-any.whl (2.7 MB)
[K     |████████████████████████████████| 2.7 MB 4.6 MB/s 
[?25hInstalling collected packages: pyspellchecker
Successfully installed pyspellchecker-0.6.3


In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as torch_data

In [3]:
import numpy as np
import wget
import json

import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
from nltk.corpus import stopwords

from textblob import TextBlob 
import re

from spellchecker import SpellChecker

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


In [5]:
from torch import cuda

seed = 4814

if cuda.is_available():
  device = 'cuda'
  torch.cuda.manual_seed_all(seed)
else:
  print('running on cpu')
  device = 'cpu'

torch.cuda.get_device_name(0)

'Tesla K80'

In [6]:
data = wget.download('https://raw.githubusercontent.com/sjtuprog/fox-news-comments/master/fox-news-comments.json')

modern_data = wget.download('https://raw.githubusercontent.com/michealman114/NLP_Hate_Speech_Detection/main/modern_comments.json')

print(data)
print(modern_data)

fox-news-comments.json
modern_comments.json


In [7]:
!ls

fox-news-comments.json	modern_comments.json  sample_data


In [8]:
train_lines = open("fox-news-comments.json", "r").readlines()
test_lines = open("modern_comments.json", "r").readlines()

In [9]:
stop_words = set(stopwords.words('english'))
stop_words.add('')

from nltk.tokenize import word_tokenize

In [13]:
def clean(data):
  max_len = 0
  max_title_len = 0  
  text_list = []
  title_list = []
  label = []
  for i in data:
    comment = json.loads(i)
    t = comment['text']
    t = ' '.join([x for x in t.split() if x[0] != '@'])
    t = ' '.join(re.findall("[a-zA-Z,.]+",t))
    t = t.replace(',', ' ')
    t = t.replace('.', ' ')
    text = word_tokenize(t)
    text = [x for x in text if x.lower() not in stop_words]
    max_len = max(max_len, len(text))
    text_list.append(text)
    title = comment['title']
    title = title.replace(',', '')
    title = title.replace('.', '')
    title = re.findall("[a-zA-Z,.]+",title)
    title_list.append(title)
    max_title_len = max(max_title_len, len(title))
    label.append(comment['label'])
  
  labels = np.array(label)
  return text_list, labels, title_list, max_len, max_title_len

train_text, train_labels, train_title, train_max_len, train_max_title_len = clean(train_lines)
test_text, test_labels, test_title, test_max_len, test_max_title_len = clean(test_lines)


In [12]:
spell = SpellChecker()
spell.word_frequency.add('obama')
spell.word_frequency.add('blm')
spell.word_frequency.add('killing')

In [16]:
import gensim.downloader as api
path = api.load("word2vec-google-news-300", return_path=True)
print(path)

/root/gensim-data/word2vec-google-news-300/word2vec-google-news-300.gz


In [17]:
import gensim
embed = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True)

In [19]:
def get_embed(word):
  x = np.zeros((300,)) # default value should be 0
  corrected = spell.correction(word) # closest correction
  if word in embed: # base word
    x = embed[word]
  elif word.upper() in embed: # capitalized (edge case for acronyms like BLM) (for some reason blm doesn't exist but BLM does?)
    x = embed[word.upper()]
  elif word.lower() in embed: # opposite of capitalization
    x = embed[word.lower()]
  elif corrected in embed: # last case, check if closest correction exists (might be bad, some corrections are kinda ass)
    x = embed[corrected]
  
  return x

In [20]:
print('comments',train_max_len)
print('title',train_max_title_len)
print(len(train_text))
print(len(train_title))
print(train_labels.shape)

comments 244
title 13
1528
1528
(1528,)


In [21]:
print('test_comments', test_max_len)
print('test_title', test_max_title_len)
print(len(test_text))
print(len(test_title))
print(test_labels.shape)

test_comments 126
test_title 19
102
102
(102,)


In [22]:
def to_array(comments, titles, max_comment_len, max_title_len):
  data_array = np.zeros((max_comment_len, len(comments), 300))
  title_array = np.zeros((max_title_len, len(titles), 300))
  for ix1, sent in enumerate(comments):
    for ix2, word in enumerate(sent):
      data_array[ix2,ix1] = get_embed(word)
  for ix1, title in enumerate(titles):
    for ix2, word in enumerate(title):
      title_array[ix2,ix1] = get_embed(word)
  
  return data_array, title_array

train_data_array, train_title_array = to_array(train_text, train_title, train_max_len, train_max_title_len)
test_data_array, test_title_array = to_array(test_text, test_title, test_max_len, test_max_title_len)

In [23]:
print(train_data_array.shape, train_title_array.shape)
print(test_data_array.shape, test_title_array.shape)

(244, 1528, 300) (13, 1528, 300)
(126, 102, 300) (19, 102, 300)


In [24]:
"""
data_array.shape = (244, 1528, 300)
data_labels.shape = (1528,)
data is an (L,N,D) array
L = max_length of sequence
N = batch_size
D = embed_dim
"""

#Source: https://likegeeks.com/numpy-shuffle/#Shuffle_multiple_NumPy_arrays_together
def custom_shuffle(data,titles,labels):
    _, num_samples, _ = data.shape
    shuffled_indices = np.random.permutation(num_samples) #return a permutation of the indices
    new_data = data[:,shuffled_indices,:]
    new_labels = labels[shuffled_indices]
    new_titles = titles[:,shuffled_indices,:]

    return (new_data, new_titles, new_labels)
    
train_data_array,train_title_array,train_labels = custom_shuffle(train_data_array,train_title_array,train_labels)

test_data_array, test_title_array, test_labels = custom_shuffle(test_data_array, test_title_array, test_labels)

In [25]:
print(train_data_array[:,0,:].shape)

(244, 300)


In [26]:
print(train_data_array.dtype)
print(type(train_data_array))
print(train_title_array.dtype)
print(type(train_title_array))

data_array = np.float32(train_data_array)
title_array = np.float32(train_title_array)

test_data_array = np.float32(test_data_array)
test_title_array = np.float32(test_title_array)

float64
<class 'numpy.ndarray'>
float64
<class 'numpy.ndarray'>


In [None]:
class BaseModel(nn.Module): # single direction lstm, no attention
  def __init__(self, hidden_size = 100, embed_dim = 300):
    super(BaseModel, self).__init__()
    
    self.hidden_size = hidden_size
    #self.embedding = embed
    
    self.linear1 = nn.Linear(hidden_size, 150) # map context vector to value
    self.linear2 = nn.Linear(150, 1)

    #self.attention1 = nn.Linear(2*hidden_size, 50) # map hidden state vector to value
    #self.attention2 = nn.Linear(50, 1)

    self.relu = nn.ReLU()

    #self.sm = nn.Softmax(dim = 0)
    self.sigmoid = nn.Sigmoid()
    
    self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, num_layers = 1, batch_first = False, dropout = 0.2, bidirectional = False)

  def forward(self, data):
    """
    data is an (L,N,D) array
    L = max_length of sequence
    N = batch_size
    D = embed_dim
    returns an (N,1) array of probabilities that each comment is hateful
    """
    #print(type(data))
    #print(data.dtype)

    #print("data size " + str(data.size()))
    hidden_states, (_, _) = self.lstm(data) # (L,N,H) array
    
    sentences = torch.sum(hidden_states, axis = 0)

    return self.sigmoid(torch.squeeze(self.linear2(self.relu(self.linear1(sentences)))))

In [None]:
class BidiModel(nn.Module): # Bidi
  def __init__(self, hidden_size = 100, embed_dim = 300):
    super().__init__()
    
    self.hidden_size = hidden_size
    #self.embedding = embed
    
    self.linear1 = nn.Linear(2*hidden_size, hidden_size) # map context vector to value
    self.linear2 = nn.Linear(hidden_size, 1)

    #self.attention1 = nn.Linear(2*hidden_size, 50) # map hidden state vector to value
    #self.attention2 = nn.Linear(50, 1)

    self.relu = nn.ReLU()

    #self.sm = nn.Softmax(dim = 0)
    self.sigmoid = nn.Sigmoid()
    
    self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, num_layers = 1, batch_first = False, dropout = 0.2, bidirectional = True)

  def forward(self, data):
    """
    data is an (L,N,D) array
    L = max_length of sequence
    N = batch_size
    D = embed_dim
    returns an (N,1) array of probabilities that each comment is hateful
    """
    #print(type(data))
    #print(data.dtype)

    #print("data size " + str(data.size()))
    hidden_states, (_, _) = self.lstm(data) # (L,N,2H) array
    #print("HIDDEN STATES SIZE " + str(hidden_states.size()))
    #weights = self.attention2(self.relu(self.attention1(hidden_states))) # (L,N,1) array
    #print("WEIGHTS SIZE " + str(weights.size()))
    # compress hidden states into (N,2H) array
    
    #alpha = self.sm(weights.reshape(weights.shape[:-1])) # weights
    #print("ALPHA SIZE " + str(alpha.size()))

    #hidden_states = torch.moveaxis(hidden_states, -1, 0)
    #hidden_states = torch.squeeze(hidden_states)

    #print("HIDDEN STATES SIZE " + str(hidden_states.size()))

    #sentences = torch.sum(hidden_states * alpha, axis = 1)

    #print("sentences " + str(sentences.size()))
    #sentences = torch.moveaxis(sentences, 0, -1)
    #print("sentences " + str(sentences.size()))
    sentences = torch.sum(hidden_states, axis = 0)

    return self.sigmoid(torch.squeeze(self.linear2(self.relu(self.linear1(sentences)))))

In [None]:
class FullModel(nn.Module): # bidi with attention
  def __init__(self, hidden_size = 100, embed_dim = 300):
    super().__init__()
    
    self.hidden_size = hidden_size
    #self.embedding = embed
    
    self.linear1 = nn.Linear(2*hidden_size, hidden_size) # map context vector to value
    self.linear2 = nn.Linear(hidden_size, 1)

    self.attention1 = nn.Linear(2*hidden_size, 50) # map hidden state vector to value
    self.attention2 = nn.Linear(50, 1)

    self.relu = nn.ReLU()

    self.sm = nn.Softmax(dim = 0)
    self.sigmoid = nn.Sigmoid()
    
    self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, num_layers = 1, batch_first = False, dropout = 0.2, bidirectional = True)

  def forward(self, data):
    """
    data is an (L,N,D) array
    L = max_length of sequence
    N = batch_size
    D = embed_dim
    returns an (N,1) array of probabilities that each comment is hateful
    """
    #print(type(data))
    #print(data.dtype)

    #print("data size " + str(data.size()))
    hidden_states, (_, _) = self.lstm(data) # (L,N,2H) array
    #print("HIDDEN STATES SIZE " + str(hidden_states.size()))
    weights = self.attention2(self.relu(self.attention1(hidden_states))) # (L,N,1) array
    #print("WEIGHTS SIZE " + str(weights.size()))
    # compress hidden states into (N,2H) array
    
    alpha = self.sm(weights.reshape(weights.shape[:-1])) # weights
    #print("ALPHA SIZE " + str(alpha.size()))

    hidden_states = torch.moveaxis(hidden_states, -1, 0)
    #hidden_states = torch.squeeze(hidden_states)

    #print("HIDDEN STATES SIZE " + str(hidden_states.size()))

    sentences = torch.sum(hidden_states * alpha, axis = 1)

    #print("sentences " + str(sentences.size()))
    sentences = torch.moveaxis(sentences, 0, -1)
    #print("sentences " + str(sentences.size()))

    return self.sigmoid(torch.squeeze(self.linear2(self.relu(self.linear1(sentences)))))

In [None]:
class ModelWithTitle(nn.Module): # bidi with attention
  def __init__(self, hidden_size = 100, embed_dim = 300):
    super().__init__()
    
    self.hidden_size = hidden_size
    
    self.linear1 = nn.Linear(4*hidden_size, hidden_size) # map context vector to value (concatenated from parallel networks)
    self.linear2 = nn.Linear(hidden_size, 1)

    self.relu = nn.ReLU()

    self.sm = nn.Softmax(dim = 0)
    self.sigmoid = nn.Sigmoid()

    #comments
    self.attention1_comment = nn.Linear(2*hidden_size, 50) # map hidden state vector to value
    self.attention2_comment = nn.Linear(50, 1)
    
    self.lstm_comment = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, num_layers = 1, batch_first = False, dropout = 0.2, bidirectional = True)

    #titles
    self.attention1_title = nn.Linear(2*hidden_size, 50)
    self.attention2_title = nn.Linear(50, 1)

    self.lstm_title = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, num_layers = 1, batch_first = False, dropout = 0.2, bidirectional = True)

  def forward(self, comment_data, title_data):
    """
    comments is an (L1,N,D) array
    titles is an (L2,N,D) array
    L1 = max_length of sequence
    L2 = max_length of title
    N = batch_size
    D = embed_dim
    returns an (N,1) array of probabilities that each comment is hateful
    """
    #print(type(data))
    #print(data.dtype)

    #print("data size " + str(data.size()))
    hidden_states, (_, _) = self.lstm_comment(comment_data) # (L,N,2H) array

    weights = self.attention2_comment(self.relu(self.attention1_comment(hidden_states))) # (L,N,1) array
    #print("WEIGHTS SIZE " + str(weights.size()))
    # compress hidden states into (N,2H) array
    
    alpha = self.sm(weights.reshape(weights.shape[:-1])) # weights
    #print("ALPHA SIZE " + str(alpha.size()))

    hidden_states = torch.moveaxis(hidden_states, -1, 0)
    #hidden_states = torch.squeeze(hidden_states)

    #print("HIDDEN STATES SIZE " + str(hidden_states.size()))

    sentences = torch.sum(hidden_states * alpha, axis = 1)

    #print("sentences " + str(sentences.size()))
    sentences = torch.moveaxis(sentences, 0, -1) # sentences is N x 2*hidden_size
    #print("sentences " + str(sentences.size()))

    hidden_states, (_,_) = self.lstm_title(title_data)
    weights = self.attention2_title(self.relu(self.attention1_title(hidden_states)))
    alpha = self.sm(weights.reshape(weights.shape[:-1]))
    hidden_states = torch.moveaxis(hidden_states, -1, 0)
    titles = torch.sum(hidden_states * alpha, axis = 1)
    titles = torch.moveaxis(titles, 0, -1) # titles is N x 2*hidden_size

    result = torch.cat((sentences, titles), dim = 1)

    return self.sigmoid(torch.squeeze(self.linear2(self.relu(self.linear1(result)))))

In [None]:
class Dataset(torch.utils.data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, list_IDs, data, labels, titles = None):
        # initialization now works with titles, passes in optional title information
        # works the same as before, but now gets title data if you give it to it
        'Initialization'
        self.data = data
        self.titles = titles
        self.labels = labels
        self.list_IDs = list_IDs

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

  def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        X = self.data[:,ID,:]
        y = self.labels[ID]

        if self.titles is not None:
          t = self.titles[:,ID,:]
          return X, t, y
        return X, y

In [None]:
from tqdm import tqdm

def train(data, labels, n_epochs, batch_size, modeltype, model = None):
    device = torch.device('cuda')  # run on colab gpu

    if model is None:
        model = modeltype().to(device)
        
    opt = optim.Adam(model.parameters(), lr=0.001)

    training_data = Dataset(range(len(labels)), data, labels)

    loader = torch_data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

    loss_fn = nn.BCELoss()

    losses = []
    for epoch in tqdm(range(n_epochs)):
        epoch_loss = 0
        for context, label in loader:
            context = context.to(device)
            context = context.moveaxis(0, 1)
            label = label.to(device).type(torch.float32)

            preds = model.forward(context)
            #print(preds)
            #print(label)

            opt.zero_grad()
            loss = loss_fn(preds, label)
            loss.backward()
            opt.step()

            epoch_loss += loss.item()
        print('Loss:', epoch_loss)
        losses.append(epoch_loss)

    print(losses)
    return model

In [None]:
def train_with_titles(data, titles, labels, n_epochs, batch_size, model = None):
    device = torch.device('cuda')  # run on colab gpu

    if model is None:
        model = ModelWithTitle().to(device)
        
    opt = optim.Adam(model.parameters(), lr=0.001)

    training_data = Dataset(range(len(labels)), data, labels, titles = titles)

    loader = torch_data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

    loss_fn = nn.BCELoss()

    losses = []
    for epoch in tqdm(range(n_epochs)):
        epoch_loss = 0
        for context, t, label in loader:
            context = context.to(device)
            context = context.moveaxis(0, 1)
            t = t.to(device)
            t = t.moveaxis(0,1)
            label = label.to(device).type(torch.float32)

            preds = model.forward(context, t)
            #print(preds)
            #print(label)

            opt.zero_grad()
            loss = loss_fn(preds, label)
            loss.backward()
            opt.step()

            epoch_loss += loss.item()
        print('Loss:', epoch_loss)
        losses.append(epoch_loss)

    print(losses)
    return model

In [None]:
print(data_array.shape)
model = train_with_titles(train_data_array, train_title_array, train_labels, 30, 128)

(244, 1528, 300)


  "num_layers={}".format(dropout, num_layers))
  3%|▎         | 1/30 [00:00<00:27,  1.04it/s]

Loss: 7.5037314891815186


  7%|▋         | 2/30 [00:01<00:22,  1.26it/s]

Loss: 7.120764970779419


 10%|█         | 3/30 [00:02<00:19,  1.36it/s]

Loss: 6.922450661659241


 13%|█▎        | 4/30 [00:02<00:18,  1.41it/s]

Loss: 6.739884376525879


 17%|█▋        | 5/30 [00:03<00:17,  1.43it/s]

Loss: 6.302202552556992


 20%|██        | 6/30 [00:04<00:16,  1.45it/s]

Loss: 5.9617171585559845


 23%|██▎       | 7/30 [00:05<00:15,  1.46it/s]

Loss: 5.701364636421204


 27%|██▋       | 8/30 [00:05<00:15,  1.47it/s]

Loss: 5.525161385536194


 30%|███       | 9/30 [00:06<00:14,  1.47it/s]

Loss: 5.2886203825473785


 33%|███▎      | 10/30 [00:07<00:13,  1.48it/s]

Loss: 5.08858123421669


 37%|███▋      | 11/30 [00:07<00:12,  1.48it/s]

Loss: 4.631946474313736


 40%|████      | 12/30 [00:08<00:12,  1.48it/s]

Loss: 4.190413594245911


 43%|████▎     | 13/30 [00:09<00:11,  1.48it/s]

Loss: 3.884510189294815


 47%|████▋     | 14/30 [00:09<00:10,  1.48it/s]

Loss: 3.333829089999199


 50%|█████     | 15/30 [00:10<00:10,  1.48it/s]

Loss: 2.9280106127262115


 53%|█████▎    | 16/30 [00:11<00:09,  1.48it/s]

Loss: 2.631470948457718


 57%|█████▋    | 17/30 [00:11<00:08,  1.48it/s]

Loss: 2.2749634608626366


 60%|██████    | 18/30 [00:12<00:08,  1.48it/s]

Loss: 2.115672841668129


 63%|██████▎   | 19/30 [00:13<00:07,  1.48it/s]

Loss: 2.00664534419775


 67%|██████▋   | 20/30 [00:13<00:06,  1.49it/s]

Loss: 1.9073855429887772


 70%|███████   | 21/30 [00:14<00:06,  1.49it/s]

Loss: 1.7148129791021347


 73%|███████▎  | 22/30 [00:15<00:05,  1.49it/s]

Loss: 1.609808400273323


 77%|███████▋  | 23/30 [00:15<00:04,  1.49it/s]

Loss: 1.5555122345685959


 80%|████████  | 24/30 [00:16<00:04,  1.49it/s]

Loss: 1.5980764739215374


 83%|████████▎ | 25/30 [00:17<00:03,  1.49it/s]

Loss: 1.4204010590910912


 87%|████████▋ | 26/30 [00:17<00:02,  1.49it/s]

Loss: 1.395422376692295


 90%|█████████ | 27/30 [00:18<00:02,  1.49it/s]

Loss: 1.3622687235474586


 93%|█████████▎| 28/30 [00:19<00:01,  1.49it/s]

Loss: 1.3037434592843056


 97%|█████████▋| 29/30 [00:19<00:00,  1.49it/s]

Loss: 1.2793881930410862


100%|██████████| 30/30 [00:20<00:00,  1.46it/s]

Loss: 1.2601166889071465
[7.5037314891815186, 7.120764970779419, 6.922450661659241, 6.739884376525879, 6.302202552556992, 5.9617171585559845, 5.701364636421204, 5.525161385536194, 5.2886203825473785, 5.08858123421669, 4.631946474313736, 4.190413594245911, 3.884510189294815, 3.333829089999199, 2.9280106127262115, 2.631470948457718, 2.2749634608626366, 2.115672841668129, 2.00664534419775, 1.9073855429887772, 1.7148129791021347, 1.609808400273323, 1.5555122345685959, 1.5980764739215374, 1.4204010590910912, 1.395422376692295, 1.3622687235474586, 1.3037434592843056, 1.2793881930410862, 1.2601166889071465]





In [None]:
import sklearn
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

In [None]:
"""
data_array.shape = (244, 1528, 300)
data_labels.shape = (1528,)
data is an (L,N,D) array
L = max_length of sequence
N = batch_size
D = embed_dim
"""

def kfold_crossvalidation(data, labels, modeltype, k, n_epochs = 30, model=None):
    _, num_samples, _ = data.shape
    fraction = 1/k
    seg = int(num_samples * fraction)
    segment_indices = []
    for i in range(k):
        vall = i * seg
        valr = i * seg + seg
        segment_indices.append(list(range(vall,valr)))
    
    all_preds = []
    all_labels = []
    #actually run the ith split
    for i in range(k):
        train_indices = []
        test_indices = segment_indices[i]
        for j in range(k):
            if j != i:
                train_indices.extend(segment_indices[j])

    
        train_data = data[:,train_indices,:]
        train_labels = labels[train_indices]

        test_data = data[:,test_indices,:]
        test_labels = labels[test_indices]

        batch_size = 128
        model_i = train(train_data, train_labels, n_epochs, batch_size, modeltype)

        iter_loss, (y_pred, y_true) = test_model(model_i,test_data,test_labels)
        y_pred = torch.round(y_pred).cpu().detach().numpy()
        all_preds.append(y_pred)
        all_labels.append(y_true)
        #print(type(y_true), y_true.shape)
        #print(type(y_pred), y_pred.shape)
        print('Accuracy:',accuracy_score(y_true,y_pred))
        print('Precision, Recall, F1:',precision_recall_fscore_support(y_true, y_pred, average='binary'))
    
    print('\n===Aggregate Stats===')
    p = np.concatenate(all_preds, axis = None)
    l = np.concatenate(all_labels, axis = None)
    print('Accuracy:', accuracy_score(l, p))
    print('Precision, Recall, F1:', precision_recall_fscore_support(l, p, average = 'binary'))

def test_model(model, test_data, test_labels):
    test_dataset = Dataset(range(len(test_labels)), test_data, test_labels)

    test_loader = torch_data.DataLoader(test_dataset, batch_size=len(test_labels))
    loss_fn = nn.BCELoss()

    predictions = None
    
    for context, label in test_loader:
        context = context.to(device)
        context = context.moveaxis(0, 1)
        label = label.to(device).type(torch.float32)

        #preds is a tensor of roughly torch.Size([305])
        preds = model.forward(context)
        predictions = preds

        loss = loss_fn(preds, label)

        print(loss.item())
    

    return loss.item(), (preds, test_labels)


def test_split():
    #function to verify that the function was splitting the data correctly
    temp_data = np.random.rand(2,5,3)
    temp_labels = np.random.randint(2, size=5)
    print("data:\n", temp_data)
    print("labels:", temp_labels)

    kfold_crossvalidation(temp_data,temp_labels,10)

kfold_crossvalidation(train_data_array,train_labels, modeltype = FullModel, k = 10,n_epochs = 30, model = None)

  "num_layers={}".format(dropout, num_layers))
  3%|▎         | 1/30 [00:00<00:16,  1.81it/s]

Loss: 7.161470472812653


  7%|▋         | 2/30 [00:01<00:15,  1.79it/s]

Loss: 6.5919076800346375


 10%|█         | 3/30 [00:01<00:15,  1.80it/s]

Loss: 6.530954301357269


 13%|█▎        | 4/30 [00:02<00:14,  1.79it/s]

Loss: 6.49009907245636


 17%|█▋        | 5/30 [00:02<00:13,  1.79it/s]

Loss: 6.327613532543182


 20%|██        | 6/30 [00:03<00:13,  1.78it/s]

Loss: 5.976559937000275


 23%|██▎       | 7/30 [00:03<00:12,  1.78it/s]

Loss: 5.686737537384033


 27%|██▋       | 8/30 [00:04<00:12,  1.78it/s]

Loss: 5.428859829902649


 30%|███       | 9/30 [00:05<00:11,  1.78it/s]

Loss: 5.148686796426773


 33%|███▎      | 10/30 [00:05<00:11,  1.78it/s]

Loss: 4.873319834470749


 37%|███▋      | 11/30 [00:06<00:10,  1.77it/s]

Loss: 4.692749470472336


 40%|████      | 12/30 [00:06<00:10,  1.77it/s]

Loss: 4.473050951957703


 43%|████▎     | 13/30 [00:07<00:09,  1.78it/s]

Loss: 4.113484501838684


 47%|████▋     | 14/30 [00:07<00:09,  1.78it/s]

Loss: 3.8728084564208984


 50%|█████     | 15/30 [00:08<00:08,  1.78it/s]

Loss: 3.6042478382587433


 53%|█████▎    | 16/30 [00:08<00:07,  1.77it/s]

Loss: 3.3028431683778763


 57%|█████▋    | 17/30 [00:09<00:07,  1.78it/s]

Loss: 2.821358636021614


 60%|██████    | 18/30 [00:10<00:06,  1.78it/s]

Loss: 2.490732818841934


 63%|██████▎   | 19/30 [00:10<00:06,  1.78it/s]

Loss: 2.120576858520508


 67%|██████▋   | 20/30 [00:11<00:05,  1.78it/s]

Loss: 2.0415646359324455


 70%|███████   | 21/30 [00:11<00:05,  1.78it/s]

Loss: 2.187877379357815


 73%|███████▎  | 22/30 [00:12<00:04,  1.78it/s]

Loss: 1.8637932762503624


 77%|███████▋  | 23/30 [00:12<00:03,  1.77it/s]

Loss: 1.7706745639443398


 80%|████████  | 24/30 [00:13<00:03,  1.77it/s]

Loss: 1.6240072920918465


 83%|████████▎ | 25/30 [00:14<00:02,  1.77it/s]

Loss: 1.5337569415569305


 87%|████████▋ | 26/30 [00:14<00:02,  1.77it/s]

Loss: 1.4012932106852531


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 1.2965248674154282


 93%|█████████▎| 28/30 [00:15<00:01,  1.77it/s]

Loss: 1.1643900200724602


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.1001594960689545


100%|██████████| 30/30 [00:16<00:00,  1.78it/s]

Loss: 1.0647527873516083
[7.161470472812653, 6.5919076800346375, 6.530954301357269, 6.49009907245636, 6.327613532543182, 5.976559937000275, 5.686737537384033, 5.428859829902649, 5.148686796426773, 4.873319834470749, 4.692749470472336, 4.473050951957703, 4.113484501838684, 3.8728084564208984, 3.6042478382587433, 3.3028431683778763, 2.821358636021614, 2.490732818841934, 2.120576858520508, 2.0415646359324455, 2.187877379357815, 1.8637932762503624, 1.7706745639443398, 1.6240072920918465, 1.5337569415569305, 1.4012932106852531, 1.2965248674154282, 1.1643900200724602, 1.1001594960689545, 1.0647527873516083]
1.1229056119918823
Accuracy: 0.7302631578947368
Precision, Recall, F1: (0.5428571428571428, 0.4318181818181818, 0.4810126582278481, None)



  "num_layers={}".format(dropout, num_layers))
  3%|▎         | 1/30 [00:00<00:16,  1.78it/s]

Loss: 7.252209782600403


  7%|▋         | 2/30 [00:01<00:15,  1.77it/s]

Loss: 6.5522648096084595


 10%|█         | 3/30 [00:01<00:15,  1.78it/s]

Loss: 6.520946443080902


 13%|█▎        | 4/30 [00:02<00:14,  1.78it/s]

Loss: 6.489237487316132


 17%|█▋        | 5/30 [00:02<00:14,  1.78it/s]

Loss: 6.289390981197357


 20%|██        | 6/30 [00:03<00:13,  1.78it/s]

Loss: 6.035727024078369


 23%|██▎       | 7/30 [00:03<00:12,  1.78it/s]

Loss: 5.865457147359848


 27%|██▋       | 8/30 [00:04<00:12,  1.78it/s]

Loss: 5.476056098937988


 30%|███       | 9/30 [00:05<00:11,  1.78it/s]

Loss: 5.232875674962997


 33%|███▎      | 10/30 [00:05<00:11,  1.77it/s]

Loss: 4.954723089933395


 37%|███▋      | 11/30 [00:06<00:10,  1.77it/s]

Loss: 4.778095573186874


 40%|████      | 12/30 [00:06<00:10,  1.77it/s]

Loss: 4.570939838886261


 43%|████▎     | 13/30 [00:07<00:09,  1.77it/s]

Loss: 4.380543678998947


 47%|████▋     | 14/30 [00:07<00:09,  1.77it/s]

Loss: 3.9323702454566956


 50%|█████     | 15/30 [00:08<00:08,  1.77it/s]

Loss: 3.5059328824281693


 53%|█████▎    | 16/30 [00:09<00:07,  1.77it/s]

Loss: 3.353578120470047


 57%|█████▋    | 17/30 [00:09<00:07,  1.77it/s]

Loss: 3.067207172513008


 60%|██████    | 18/30 [00:10<00:06,  1.77it/s]

Loss: 2.7971176207065582


 63%|██████▎   | 19/30 [00:10<00:06,  1.78it/s]

Loss: 2.4704372584819794


 67%|██████▋   | 20/30 [00:11<00:05,  1.78it/s]

Loss: 2.2156896740198135


 70%|███████   | 21/30 [00:11<00:05,  1.78it/s]

Loss: 2.5211153775453568


 73%|███████▎  | 22/30 [00:12<00:04,  1.78it/s]

Loss: 2.1998740285634995


 77%|███████▋  | 23/30 [00:12<00:03,  1.78it/s]

Loss: 1.9562402442097664


 80%|████████  | 24/30 [00:13<00:03,  1.78it/s]

Loss: 1.7677932605147362


 83%|████████▎ | 25/30 [00:14<00:02,  1.77it/s]

Loss: 1.5303428247570992


 87%|████████▋ | 26/30 [00:14<00:02,  1.77it/s]

Loss: 1.440201759338379


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 1.2664306908845901


 93%|█████████▎| 28/30 [00:15<00:01,  1.77it/s]

Loss: 1.2097508683800697


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.091821638867259


100%|██████████| 30/30 [00:16<00:00,  1.77it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.0088547989726067
[7.252209782600403, 6.5522648096084595, 6.520946443080902, 6.489237487316132, 6.289390981197357, 6.035727024078369, 5.865457147359848, 5.476056098937988, 5.232875674962997, 4.954723089933395, 4.778095573186874, 4.570939838886261, 4.380543678998947, 3.9323702454566956, 3.5059328824281693, 3.353578120470047, 3.067207172513008, 2.7971176207065582, 2.4704372584819794, 2.2156896740198135, 2.5211153775453568, 2.1998740285634995, 1.9562402442097664, 1.7677932605147362, 1.5303428247570992, 1.440201759338379, 1.2664306908845901, 1.2097508683800697, 1.091821638867259, 1.0088547989726067]
1.1945710182189941
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.5365853658536586, 0.4888888888888889, 0.5116279069767442, None)


  3%|▎         | 1/30 [00:00<00:16,  1.77it/s]

Loss: 7.231850445270538


  7%|▋         | 2/30 [00:01<00:15,  1.77it/s]

Loss: 6.68852311372757


 10%|█         | 3/30 [00:01<00:15,  1.77it/s]

Loss: 6.634013772010803


 13%|█▎        | 4/30 [00:02<00:14,  1.77it/s]

Loss: 6.549228727817535


 17%|█▋        | 5/30 [00:02<00:14,  1.78it/s]

Loss: 6.296201825141907


 20%|██        | 6/30 [00:03<00:13,  1.78it/s]

Loss: 5.883107304573059


 23%|██▎       | 7/30 [00:03<00:12,  1.78it/s]

Loss: 5.525115102529526


 27%|██▋       | 8/30 [00:04<00:12,  1.78it/s]

Loss: 5.18276783823967


 30%|███       | 9/30 [00:05<00:11,  1.78it/s]

Loss: 5.053667277097702


 33%|███▎      | 10/30 [00:05<00:11,  1.77it/s]

Loss: 4.660975605249405


 37%|███▋      | 11/30 [00:06<00:10,  1.77it/s]

Loss: 4.4409023225307465


 40%|████      | 12/30 [00:06<00:10,  1.77it/s]

Loss: 4.149117171764374


 43%|████▎     | 13/30 [00:07<00:09,  1.78it/s]

Loss: 3.8701827228069305


 47%|████▋     | 14/30 [00:07<00:09,  1.77it/s]

Loss: 3.8316816687583923


 50%|█████     | 15/30 [00:08<00:08,  1.77it/s]

Loss: 3.484685778617859


 53%|█████▎    | 16/30 [00:09<00:07,  1.77it/s]

Loss: 3.0067983269691467


 57%|█████▋    | 17/30 [00:09<00:07,  1.77it/s]

Loss: 2.7451503425836563


 60%|██████    | 18/30 [00:10<00:06,  1.77it/s]

Loss: 2.3380885273218155


 63%|██████▎   | 19/30 [00:10<00:06,  1.77it/s]

Loss: 1.99125574529171


 67%|██████▋   | 20/30 [00:11<00:05,  1.78it/s]

Loss: 1.6317827180027962


 70%|███████   | 21/30 [00:11<00:05,  1.78it/s]

Loss: 1.4866941720247269


 73%|███████▎  | 22/30 [00:12<00:04,  1.77it/s]

Loss: 1.2757514715194702


 77%|███████▋  | 23/30 [00:12<00:03,  1.78it/s]

Loss: 1.016661636531353


 80%|████████  | 24/30 [00:13<00:03,  1.78it/s]

Loss: 0.9276844542473555


 83%|████████▎ | 25/30 [00:14<00:02,  1.77it/s]

Loss: 0.800710704177618


 87%|████████▋ | 26/30 [00:14<00:02,  1.77it/s]

Loss: 0.7684451304376125


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 0.7151028299704194


 93%|█████████▎| 28/30 [00:15<00:01,  1.76it/s]

Loss: 0.6270913388580084


 97%|█████████▋| 29/30 [00:16<00:00,  1.76it/s]

Loss: 0.5949679501354694


100%|██████████| 30/30 [00:16<00:00,  1.77it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 0.7113088136538863
[7.231850445270538, 6.68852311372757, 6.634013772010803, 6.549228727817535, 6.296201825141907, 5.883107304573059, 5.525115102529526, 5.18276783823967, 5.053667277097702, 4.660975605249405, 4.4409023225307465, 4.149117171764374, 3.8701827228069305, 3.8316816687583923, 3.484685778617859, 3.0067983269691467, 2.7451503425836563, 2.3380885273218155, 1.99125574529171, 1.6317827180027962, 1.4866941720247269, 1.2757514715194702, 1.016661636531353, 0.9276844542473555, 0.800710704177618, 0.7684451304376125, 0.7151028299704194, 0.6270913388580084, 0.5949679501354694, 0.7113088136538863]
1.390955924987793
Accuracy: 0.8223684210526315
Precision, Recall, F1: (0.6285714285714286, 0.6111111111111112, 0.619718309859155, None)


  3%|▎         | 1/30 [00:00<00:16,  1.77it/s]

Loss: 7.4496824741363525


  7%|▋         | 2/30 [00:01<00:15,  1.76it/s]

Loss: 6.730735182762146


 10%|█         | 3/30 [00:01<00:15,  1.76it/s]

Loss: 6.616629600524902


 13%|█▎        | 4/30 [00:02<00:14,  1.77it/s]

Loss: 6.493064105510712


 17%|█▋        | 5/30 [00:02<00:14,  1.77it/s]

Loss: 6.417547523975372


 20%|██        | 6/30 [00:03<00:13,  1.77it/s]

Loss: 6.2276692390441895


 23%|██▎       | 7/30 [00:03<00:13,  1.77it/s]

Loss: 5.855240315198898


 27%|██▋       | 8/30 [00:04<00:12,  1.76it/s]

Loss: 5.522601693868637


 30%|███       | 9/30 [00:05<00:11,  1.76it/s]

Loss: 5.232211410999298


 33%|███▎      | 10/30 [00:05<00:11,  1.76it/s]

Loss: 4.84644889831543


 37%|███▋      | 11/30 [00:06<00:10,  1.76it/s]

Loss: 4.549790740013123


 40%|████      | 12/30 [00:06<00:10,  1.76it/s]

Loss: 4.239611715078354


 43%|████▎     | 13/30 [00:07<00:09,  1.76it/s]

Loss: 3.8885496854782104


 47%|████▋     | 14/30 [00:07<00:09,  1.77it/s]

Loss: 3.439583331346512


 50%|█████     | 15/30 [00:08<00:08,  1.76it/s]

Loss: 3.174232006072998


 53%|█████▎    | 16/30 [00:09<00:07,  1.76it/s]

Loss: 2.8356460630893707


 57%|█████▋    | 17/30 [00:09<00:07,  1.76it/s]

Loss: 2.4220950305461884


 60%|██████    | 18/30 [00:10<00:06,  1.76it/s]

Loss: 2.3805021345615387


 63%|██████▎   | 19/30 [00:10<00:06,  1.76it/s]

Loss: 2.0988940075039864


 67%|██████▋   | 20/30 [00:11<00:05,  1.77it/s]

Loss: 1.844334416091442


 70%|███████   | 21/30 [00:11<00:05,  1.78it/s]

Loss: 1.5292153134942055


 73%|███████▎  | 22/30 [00:12<00:04,  1.78it/s]

Loss: 1.316838264465332


 77%|███████▋  | 23/30 [00:13<00:03,  1.77it/s]

Loss: 1.1831741072237492


 80%|████████  | 24/30 [00:13<00:03,  1.77it/s]

Loss: 1.1247880682349205


 83%|████████▎ | 25/30 [00:14<00:02,  1.76it/s]

Loss: 1.0676222499459982


 87%|████████▋ | 26/30 [00:14<00:02,  1.76it/s]

Loss: 1.0434539504349232


 90%|█████████ | 27/30 [00:15<00:01,  1.76it/s]

Loss: 1.1897021271288395


 93%|█████████▎| 28/30 [00:15<00:01,  1.76it/s]

Loss: 2.30994376540184


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.5454294234514236


100%|██████████| 30/30 [00:16<00:00,  1.76it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.2099626045674086
[7.4496824741363525, 6.730735182762146, 6.616629600524902, 6.493064105510712, 6.417547523975372, 6.2276692390441895, 5.855240315198898, 5.522601693868637, 5.232211410999298, 4.84644889831543, 4.549790740013123, 4.239611715078354, 3.8885496854782104, 3.439583331346512, 3.174232006072998, 2.8356460630893707, 2.4220950305461884, 2.3805021345615387, 2.0988940075039864, 1.844334416091442, 1.5292153134942055, 1.316838264465332, 1.1831741072237492, 1.1247880682349205, 1.0676222499459982, 1.0434539504349232, 1.1897021271288395, 2.30994376540184, 1.5454294234514236, 1.2099626045674086]
0.9035958051681519
Accuracy: 0.75
Precision, Recall, F1: (0.631578947368421, 0.5, 0.5581395348837209, None)


  3%|▎         | 1/30 [00:00<00:16,  1.77it/s]

Loss: 7.219363808631897


  7%|▋         | 2/30 [00:01<00:15,  1.76it/s]

Loss: 6.6780752539634705


 10%|█         | 3/30 [00:01<00:15,  1.76it/s]

Loss: 6.610791921615601


 13%|█▎        | 4/30 [00:02<00:14,  1.76it/s]

Loss: 6.580202996730804


 17%|█▋        | 5/30 [00:02<00:14,  1.76it/s]

Loss: 6.386736810207367


 20%|██        | 6/30 [00:03<00:13,  1.76it/s]

Loss: 6.050075888633728


 23%|██▎       | 7/30 [00:03<00:13,  1.76it/s]

Loss: 5.630377531051636


 27%|██▋       | 8/30 [00:04<00:12,  1.75it/s]

Loss: 5.362293064594269


 30%|███       | 9/30 [00:05<00:11,  1.76it/s]

Loss: 5.13444197177887


 33%|███▎      | 10/30 [00:05<00:11,  1.75it/s]

Loss: 4.820311993360519


 37%|███▋      | 11/30 [00:06<00:10,  1.75it/s]

Loss: 4.6450159549713135


 40%|████      | 12/30 [00:06<00:10,  1.76it/s]

Loss: 4.563366919755936


 43%|████▎     | 13/30 [00:07<00:09,  1.76it/s]

Loss: 4.211198419332504


 47%|████▋     | 14/30 [00:07<00:09,  1.76it/s]

Loss: 3.991465210914612


 50%|█████     | 15/30 [00:08<00:08,  1.76it/s]

Loss: 3.6322764605283737


 53%|█████▎    | 16/30 [00:09<00:07,  1.75it/s]

Loss: 3.6060781478881836


 57%|█████▋    | 17/30 [00:09<00:07,  1.75it/s]

Loss: 3.31862074136734


 60%|██████    | 18/30 [00:10<00:06,  1.75it/s]

Loss: 2.7853978276252747


 63%|██████▎   | 19/30 [00:10<00:06,  1.75it/s]

Loss: 2.4806663542985916


 67%|██████▋   | 20/30 [00:11<00:05,  1.76it/s]

Loss: 2.2687553614377975


 70%|███████   | 21/30 [00:11<00:05,  1.76it/s]

Loss: 2.4069131165742874


 73%|███████▎  | 22/30 [00:12<00:04,  1.76it/s]

Loss: 2.185056358575821


 77%|███████▋  | 23/30 [00:13<00:03,  1.76it/s]

Loss: 1.8169047385454178


 80%|████████  | 24/30 [00:13<00:03,  1.75it/s]

Loss: 1.5011714026331902


 83%|████████▎ | 25/30 [00:14<00:02,  1.75it/s]

Loss: 1.2844415344297886


 87%|████████▋ | 26/30 [00:14<00:02,  1.76it/s]

Loss: 1.335901103913784


 90%|█████████ | 27/30 [00:15<00:01,  1.75it/s]

Loss: 1.27716900780797


 93%|█████████▎| 28/30 [00:15<00:01,  1.75it/s]

Loss: 1.0620783604681492


 97%|█████████▋| 29/30 [00:16<00:00,  1.75it/s]

Loss: 0.9128562957048416


100%|██████████| 30/30 [00:17<00:00,  1.76it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 0.8399140611290932
[7.219363808631897, 6.6780752539634705, 6.610791921615601, 6.580202996730804, 6.386736810207367, 6.050075888633728, 5.630377531051636, 5.362293064594269, 5.13444197177887, 4.820311993360519, 4.6450159549713135, 4.563366919755936, 4.211198419332504, 3.991465210914612, 3.6322764605283737, 3.6060781478881836, 3.31862074136734, 2.7853978276252747, 2.4806663542985916, 2.2687553614377975, 2.4069131165742874, 2.185056358575821, 1.8169047385454178, 1.5011714026331902, 1.2844415344297886, 1.335901103913784, 1.27716900780797, 1.0620783604681492, 0.9128562957048416, 0.8399140611290932]
0.9738895893096924
Accuracy: 0.75
Precision, Recall, F1: (0.5862068965517241, 0.3953488372093023, 0.4722222222222222, None)


  3%|▎         | 1/30 [00:00<00:16,  1.79it/s]

Loss: 7.440488517284393


  7%|▋         | 2/30 [00:01<00:15,  1.78it/s]

Loss: 6.657531201839447


 10%|█         | 3/30 [00:01<00:15,  1.77it/s]

Loss: 6.602605819702148


 13%|█▎        | 4/30 [00:02<00:14,  1.76it/s]

Loss: 6.558432936668396


 17%|█▋        | 5/30 [00:02<00:14,  1.76it/s]

Loss: 6.4616809487342834


 20%|██        | 6/30 [00:03<00:13,  1.76it/s]

Loss: 6.130834877490997


 23%|██▎       | 7/30 [00:03<00:13,  1.76it/s]

Loss: 5.816828608512878


 27%|██▋       | 8/30 [00:04<00:12,  1.76it/s]

Loss: 5.513843268156052


 30%|███       | 9/30 [00:05<00:11,  1.77it/s]

Loss: 5.3184908628463745


 33%|███▎      | 10/30 [00:05<00:11,  1.76it/s]

Loss: 5.067902624607086


 37%|███▋      | 11/30 [00:06<00:10,  1.76it/s]

Loss: 4.8336136639118195


 40%|████      | 12/30 [00:06<00:10,  1.76it/s]

Loss: 4.566974431276321


 43%|████▎     | 13/30 [00:07<00:09,  1.76it/s]

Loss: 4.200301796197891


 47%|████▋     | 14/30 [00:07<00:09,  1.76it/s]

Loss: 3.9230076670646667


 50%|█████     | 15/30 [00:08<00:08,  1.77it/s]

Loss: 3.5247256010770798


 53%|█████▎    | 16/30 [00:09<00:07,  1.76it/s]

Loss: 3.2339989989995956


 57%|█████▋    | 17/30 [00:09<00:07,  1.76it/s]

Loss: 3.1655759662389755


 60%|██████    | 18/30 [00:10<00:06,  1.76it/s]

Loss: 2.823238968849182


 63%|██████▎   | 19/30 [00:10<00:06,  1.76it/s]

Loss: 2.4697202891111374


 67%|██████▋   | 20/30 [00:11<00:05,  1.75it/s]

Loss: 2.2571995556354523


 70%|███████   | 21/30 [00:11<00:05,  1.75it/s]

Loss: 2.1001860573887825


 73%|███████▎  | 22/30 [00:12<00:04,  1.76it/s]

Loss: 1.6869202852249146


 77%|███████▋  | 23/30 [00:13<00:03,  1.76it/s]

Loss: 1.457093633711338


 80%|████████  | 24/30 [00:13<00:03,  1.76it/s]

Loss: 1.4091816805303097


 83%|████████▎ | 25/30 [00:14<00:02,  1.76it/s]

Loss: 1.292484112083912


 87%|████████▋ | 26/30 [00:14<00:02,  1.77it/s]

Loss: 1.1231594122946262


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 1.1054626312106848


 93%|█████████▎| 28/30 [00:15<00:01,  1.76it/s]

Loss: 1.1192189753055573


 97%|█████████▋| 29/30 [00:16<00:00,  1.76it/s]

Loss: 1.3192428685724735


100%|██████████| 30/30 [00:17<00:00,  1.76it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 0.9865379929542542
[7.440488517284393, 6.657531201839447, 6.602605819702148, 6.558432936668396, 6.4616809487342834, 6.130834877490997, 5.816828608512878, 5.513843268156052, 5.3184908628463745, 5.067902624607086, 4.8336136639118195, 4.566974431276321, 4.200301796197891, 3.9230076670646667, 3.5247256010770798, 3.2339989989995956, 3.1655759662389755, 2.823238968849182, 2.4697202891111374, 2.2571995556354523, 2.1001860573887825, 1.6869202852249146, 1.457093633711338, 1.4091816805303097, 1.292484112083912, 1.1231594122946262, 1.1054626312106848, 1.1192189753055573, 1.3192428685724735, 0.9865379929542542]
0.8304365873336792
Accuracy: 0.7828947368421053
Precision, Recall, F1: (0.5833333333333334, 0.5384615384615384, 0.5599999999999999, None)


  3%|▎         | 1/30 [00:00<00:16,  1.78it/s]

Loss: 7.392193794250488


  7%|▋         | 2/30 [00:01<00:15,  1.76it/s]

Loss: 6.5906625390052795


 10%|█         | 3/30 [00:01<00:15,  1.77it/s]

Loss: 6.499975919723511


 13%|█▎        | 4/30 [00:02<00:14,  1.77it/s]

Loss: 6.48433393239975


 17%|█▋        | 5/30 [00:02<00:14,  1.77it/s]

Loss: 6.370480716228485


 20%|██        | 6/30 [00:03<00:13,  1.77it/s]

Loss: 6.103454828262329


 23%|██▎       | 7/30 [00:03<00:12,  1.77it/s]

Loss: 5.683687090873718


 27%|██▋       | 8/30 [00:04<00:12,  1.77it/s]

Loss: 5.439664512872696


 30%|███       | 9/30 [00:05<00:11,  1.77it/s]

Loss: 5.193986564874649


 33%|███▎      | 10/30 [00:05<00:11,  1.77it/s]

Loss: 4.932405084371567


 37%|███▋      | 11/30 [00:06<00:10,  1.76it/s]

Loss: 4.721827685832977


 40%|████      | 12/30 [00:06<00:10,  1.75it/s]

Loss: 4.413560211658478


 43%|████▎     | 13/30 [00:07<00:09,  1.75it/s]

Loss: 4.300531327724457


 47%|████▋     | 14/30 [00:07<00:09,  1.75it/s]

Loss: 4.0747604966163635


 50%|█████     | 15/30 [00:08<00:08,  1.76it/s]

Loss: 3.7569516599178314


 53%|█████▎    | 16/30 [00:09<00:07,  1.76it/s]

Loss: 3.3218430280685425


 57%|█████▋    | 17/30 [00:09<00:07,  1.76it/s]

Loss: 2.9172238260507584


 60%|██████    | 18/30 [00:10<00:06,  1.76it/s]

Loss: 2.590586841106415


 63%|██████▎   | 19/30 [00:10<00:06,  1.75it/s]

Loss: 2.6219917982816696


 67%|██████▋   | 20/30 [00:11<00:05,  1.75it/s]

Loss: 2.3464084565639496


 70%|███████   | 21/30 [00:11<00:05,  1.75it/s]

Loss: 1.9531150981783867


 73%|███████▎  | 22/30 [00:12<00:04,  1.76it/s]

Loss: 1.6486174762248993


 77%|███████▋  | 23/30 [00:13<00:03,  1.76it/s]

Loss: 1.521292395889759


 80%|████████  | 24/30 [00:13<00:03,  1.76it/s]

Loss: 1.2890072502195835


 83%|████████▎ | 25/30 [00:14<00:02,  1.75it/s]

Loss: 1.1928284615278244


 87%|████████▋ | 26/30 [00:14<00:02,  1.75it/s]

Loss: 1.4359319359064102


 90%|█████████ | 27/30 [00:15<00:01,  1.76it/s]

Loss: 1.4571885764598846


 93%|█████████▎| 28/30 [00:15<00:01,  1.76it/s]

Loss: 1.2202797308564186


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.1002238765358925


100%|██████████| 30/30 [00:17<00:00,  1.76it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.0902759209275246
[7.392193794250488, 6.5906625390052795, 6.499975919723511, 6.48433393239975, 6.370480716228485, 6.103454828262329, 5.683687090873718, 5.439664512872696, 5.193986564874649, 4.932405084371567, 4.721827685832977, 4.413560211658478, 4.300531327724457, 4.0747604966163635, 3.7569516599178314, 3.3218430280685425, 2.9172238260507584, 2.590586841106415, 2.6219917982816696, 2.3464084565639496, 1.9531150981783867, 1.6486174762248993, 1.521292395889759, 1.2890072502195835, 1.1928284615278244, 1.4359319359064102, 1.4571885764598846, 1.2202797308564186, 1.1002238765358925, 1.0902759209275246]
0.9111956357955933
Accuracy: 0.7631578947368421
Precision, Recall, F1: (0.6538461538461539, 0.6538461538461539, 0.6538461538461539, None)


  3%|▎         | 1/30 [00:00<00:16,  1.76it/s]

Loss: 7.266852557659149


  7%|▋         | 2/30 [00:01<00:15,  1.77it/s]

Loss: 6.71310567855835


 10%|█         | 3/30 [00:01<00:15,  1.78it/s]

Loss: 6.639466881752014


 13%|█▎        | 4/30 [00:02<00:14,  1.78it/s]

Loss: 6.558628141880035


 17%|█▋        | 5/30 [00:02<00:14,  1.77it/s]

Loss: 6.437647223472595


 20%|██        | 6/30 [00:03<00:13,  1.78it/s]

Loss: 5.943740874528885


 23%|██▎       | 7/30 [00:03<00:12,  1.78it/s]

Loss: 5.6292645037174225


 27%|██▋       | 8/30 [00:04<00:12,  1.77it/s]

Loss: 5.344928801059723


 30%|███       | 9/30 [00:05<00:11,  1.77it/s]

Loss: 5.062857627868652


 33%|███▎      | 10/30 [00:05<00:11,  1.77it/s]

Loss: 4.897380739450455


 37%|███▋      | 11/30 [00:06<00:10,  1.77it/s]

Loss: 4.595435827970505


 40%|████      | 12/30 [00:06<00:10,  1.76it/s]

Loss: 4.333438366651535


 43%|████▎     | 13/30 [00:07<00:09,  1.77it/s]

Loss: 4.11571004986763


 47%|████▋     | 14/30 [00:07<00:09,  1.77it/s]

Loss: 3.8402989208698273


 50%|█████     | 15/30 [00:08<00:08,  1.77it/s]

Loss: 3.5193296670913696


 53%|█████▎    | 16/30 [00:09<00:07,  1.77it/s]

Loss: 3.309605821967125


 57%|█████▋    | 17/30 [00:09<00:07,  1.77it/s]

Loss: 2.85959854722023


 60%|██████    | 18/30 [00:10<00:06,  1.77it/s]

Loss: 2.571341559290886


 63%|██████▎   | 19/30 [00:10<00:06,  1.77it/s]

Loss: 2.3248498886823654


 67%|██████▋   | 20/30 [00:11<00:05,  1.77it/s]

Loss: 1.9562107026576996


 70%|███████   | 21/30 [00:11<00:05,  1.77it/s]

Loss: 1.7725707367062569


 73%|███████▎  | 22/30 [00:12<00:04,  1.76it/s]

Loss: 1.4970396012067795


 77%|███████▋  | 23/30 [00:13<00:03,  1.76it/s]

Loss: 1.5050532147288322


 80%|████████  | 24/30 [00:13<00:03,  1.76it/s]

Loss: 1.263139270246029


 83%|████████▎ | 25/30 [00:14<00:02,  1.76it/s]

Loss: 1.0365842655301094


 87%|████████▋ | 26/30 [00:14<00:02,  1.76it/s]

Loss: 1.1314156018197536


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 1.9125632140785456


 93%|█████████▎| 28/30 [00:15<00:01,  1.77it/s]

Loss: 1.3374944776296616


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.0989512540400028


100%|██████████| 30/30 [00:16<00:00,  1.77it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.018294584006071
[7.266852557659149, 6.71310567855835, 6.639466881752014, 6.558628141880035, 6.437647223472595, 5.943740874528885, 5.6292645037174225, 5.344928801059723, 5.062857627868652, 4.897380739450455, 4.595435827970505, 4.333438366651535, 4.11571004986763, 3.8402989208698273, 3.5193296670913696, 3.309605821967125, 2.85959854722023, 2.571341559290886, 2.3248498886823654, 1.9562107026576996, 1.7725707367062569, 1.4970396012067795, 1.5050532147288322, 1.263139270246029, 1.0365842655301094, 1.1314156018197536, 1.9125632140785456, 1.3374944776296616, 1.0989512540400028, 1.018294584006071]
1.1473394632339478
Accuracy: 0.6776315789473685
Precision, Recall, F1: (0.35714285714285715, 0.40540540540540543, 0.379746835443038, None)


  3%|▎         | 1/30 [00:00<00:16,  1.79it/s]

Loss: 7.37364661693573


  7%|▋         | 2/30 [00:01<00:15,  1.78it/s]

Loss: 6.783625841140747


 10%|█         | 3/30 [00:01<00:15,  1.78it/s]

Loss: 6.575734853744507


 13%|█▎        | 4/30 [00:02<00:14,  1.78it/s]

Loss: 6.530081748962402


 17%|█▋        | 5/30 [00:02<00:14,  1.78it/s]

Loss: 6.409852147102356


 20%|██        | 6/30 [00:03<00:13,  1.78it/s]

Loss: 6.122947990894318


 23%|██▎       | 7/30 [00:03<00:12,  1.78it/s]

Loss: 5.68210956454277


 27%|██▋       | 8/30 [00:04<00:12,  1.78it/s]

Loss: 5.438181430101395


 30%|███       | 9/30 [00:05<00:11,  1.78it/s]

Loss: 5.194558024406433


 33%|███▎      | 10/30 [00:05<00:11,  1.78it/s]

Loss: 4.935616135597229


 37%|███▋      | 11/30 [00:06<00:10,  1.78it/s]

Loss: 4.5959742069244385


 40%|████      | 12/30 [00:06<00:10,  1.79it/s]

Loss: 4.447739332914352


 43%|████▎     | 13/30 [00:07<00:09,  1.80it/s]

Loss: 4.165482252836227


 47%|████▋     | 14/30 [00:07<00:08,  1.79it/s]

Loss: 3.9844868779182434


 50%|█████     | 15/30 [00:08<00:08,  1.79it/s]

Loss: 3.9051913022994995


 53%|█████▎    | 16/30 [00:08<00:07,  1.79it/s]

Loss: 3.6586190164089203


 57%|█████▋    | 17/30 [00:09<00:07,  1.79it/s]

Loss: 3.191906690597534


 60%|██████    | 18/30 [00:10<00:06,  1.79it/s]

Loss: 2.897287353873253


 63%|██████▎   | 19/30 [00:10<00:06,  1.78it/s]

Loss: 2.6479208022356033


 67%|██████▋   | 20/30 [00:11<00:05,  1.79it/s]

Loss: 2.2628661394119263


 70%|███████   | 21/30 [00:11<00:05,  1.78it/s]

Loss: 2.0204537957906723


 73%|███████▎  | 22/30 [00:12<00:04,  1.78it/s]

Loss: 1.8395610302686691


 77%|███████▋  | 23/30 [00:12<00:03,  1.78it/s]

Loss: 1.5682322084903717


 80%|████████  | 24/30 [00:13<00:03,  1.77it/s]

Loss: 1.4964762702584267


 83%|████████▎ | 25/30 [00:14<00:02,  1.77it/s]

Loss: 1.3181863576173782


 87%|████████▋ | 26/30 [00:14<00:02,  1.77it/s]

Loss: 1.326012548059225


 90%|█████████ | 27/30 [00:15<00:01,  1.77it/s]

Loss: 1.1042156219482422


 93%|█████████▎| 28/30 [00:15<00:01,  1.77it/s]

Loss: 1.16416434943676


 97%|█████████▋| 29/30 [00:16<00:00,  1.77it/s]

Loss: 1.0927966330200434


100%|██████████| 30/30 [00:16<00:00,  1.78it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.1024193093180656
[7.37364661693573, 6.783625841140747, 6.575734853744507, 6.530081748962402, 6.409852147102356, 6.122947990894318, 5.68210956454277, 5.438181430101395, 5.194558024406433, 4.935616135597229, 4.5959742069244385, 4.447739332914352, 4.165482252836227, 3.9844868779182434, 3.9051913022994995, 3.6586190164089203, 3.191906690597534, 2.897287353873253, 2.6479208022356033, 2.2628661394119263, 2.0204537957906723, 1.8395610302686691, 1.5682322084903717, 1.4964762702584267, 1.3181863576173782, 1.326012548059225, 1.1042156219482422, 1.16416434943676, 1.0927966330200434, 1.1024193093180656]
1.0661855936050415
Accuracy: 0.7631578947368421
Precision, Recall, F1: (0.5454545454545454, 0.6, 0.5714285714285713, None)


  3%|▎         | 1/30 [00:00<00:16,  1.75it/s]

Loss: 7.278657615184784


  7%|▋         | 2/30 [00:01<00:16,  1.75it/s]

Loss: 6.682006597518921


 10%|█         | 3/30 [00:01<00:15,  1.75it/s]

Loss: 6.509561896324158


 13%|█▎        | 4/30 [00:02<00:14,  1.75it/s]

Loss: 6.4828126430511475


 17%|█▋        | 5/30 [00:02<00:14,  1.74it/s]

Loss: 6.351690113544464


 20%|██        | 6/30 [00:03<00:13,  1.75it/s]

Loss: 6.047308087348938


 23%|██▎       | 7/30 [00:04<00:13,  1.75it/s]

Loss: 5.648025363683701


 27%|██▋       | 8/30 [00:04<00:12,  1.74it/s]

Loss: 5.4414012134075165


 30%|███       | 9/30 [00:05<00:12,  1.74it/s]

Loss: 5.2005860805511475


 33%|███▎      | 10/30 [00:05<00:11,  1.74it/s]

Loss: 4.8962951600551605


 37%|███▋      | 11/30 [00:06<00:10,  1.74it/s]

Loss: 4.6472903192043304


 40%|████      | 12/30 [00:06<00:10,  1.74it/s]

Loss: 4.415793389081955


 43%|████▎     | 13/30 [00:07<00:09,  1.74it/s]

Loss: 4.208315819501877


 47%|████▋     | 14/30 [00:08<00:09,  1.74it/s]

Loss: 3.799087628722191


 50%|█████     | 15/30 [00:08<00:08,  1.74it/s]

Loss: 3.5180020928382874


 53%|█████▎    | 16/30 [00:09<00:08,  1.74it/s]

Loss: 2.992550954222679


 57%|█████▋    | 17/30 [00:09<00:07,  1.75it/s]

Loss: 2.7367721796035767


 60%|██████    | 18/30 [00:10<00:06,  1.75it/s]

Loss: 2.3890950977802277


 63%|██████▎   | 19/30 [00:10<00:06,  1.75it/s]

Loss: 2.315435916185379


 67%|██████▋   | 20/30 [00:11<00:05,  1.75it/s]

Loss: 2.0635610967874527


 70%|███████   | 21/30 [00:12<00:05,  1.75it/s]

Loss: 1.925570860505104


 73%|███████▎  | 22/30 [00:12<00:04,  1.75it/s]

Loss: 1.6632335186004639


 77%|███████▋  | 23/30 [00:13<00:04,  1.74it/s]

Loss: 1.6265090927481651


 80%|████████  | 24/30 [00:13<00:03,  1.74it/s]

Loss: 1.3884723782539368


 83%|████████▎ | 25/30 [00:14<00:02,  1.74it/s]

Loss: 1.4331006407737732


 87%|████████▋ | 26/30 [00:14<00:02,  1.74it/s]

Loss: 1.4603261277079582


 90%|█████████ | 27/30 [00:15<00:01,  1.74it/s]

Loss: 1.3621400818228722


 93%|█████████▎| 28/30 [00:16<00:01,  1.74it/s]

Loss: 1.266315996646881


 97%|█████████▋| 29/30 [00:16<00:00,  1.74it/s]

Loss: 1.4715285077691078


100%|██████████| 30/30 [00:17<00:00,  1.74it/s]

Loss: 1.4113472774624825
[7.278657615184784, 6.682006597518921, 6.509561896324158, 6.4828126430511475, 6.351690113544464, 6.047308087348938, 5.648025363683701, 5.4414012134075165, 5.2005860805511475, 4.8962951600551605, 4.6472903192043304, 4.415793389081955, 4.208315819501877, 3.799087628722191, 3.5180020928382874, 2.992550954222679, 2.7367721796035767, 2.3890950977802277, 2.315435916185379, 2.0635610967874527, 1.925570860505104, 1.6632335186004639, 1.6265090927481651, 1.3884723782539368, 1.4331006407737732, 1.4603261277079582, 1.3621400818228722, 1.266315996646881, 1.4715285077691078, 1.4113472774624825]
1.0908381938934326
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.6071428571428571, 0.3541666666666667, 0.4473684210526316, None)

===Aggregate Stats===
Accuracy: 0.7486842105263158
Precision, Recall, F1: (0.5657894736842105, 0.4976851851851852, 0.5295566502463054, None)





In [None]:
def kfold_crossvalidation_titles(data, titles, labels, k, n_epochs = 30, model=None):
    _, num_samples, _ = data.shape
    fraction = 1/k
    seg = int(num_samples * fraction)
    segment_indices = []
    for i in range(k):
        vall = i * seg
        valr = i * seg + seg
        segment_indices.append(list(range(vall,valr)))
    
    all_preds = []
    all_labels = []
    #actually run the ith split
    for i in range(k):
        train_indices = []
        test_indices = segment_indices[i]
        for j in range(k):
            if j != i:
                train_indices.extend(segment_indices[j])

    
        train_data = data[:,train_indices,:]
        train_titles = titles[:,train_indices,:]
        train_labels = labels[train_indices]

        test_data = data[:,test_indices,:]
        test_titles = titles[:,test_indices,:]
        test_labels = labels[test_indices]

        batch_size = 128
        model_i = train_with_titles(train_data, train_titles, train_labels, n_epochs, batch_size)

        iter_loss, (y_pred, y_true) = test_model_titles(model_i,test_data,test_titles,test_labels)
        y_pred = torch.round(y_pred).cpu().detach().numpy()
        all_preds.append(y_pred)
        all_labels.append(y_true)
        #print(type(y_true), y_true.shape)
        #print(type(y_pred), y_pred.shape)
        print('Accuracy:',accuracy_score(y_true,y_pred))
        print('Precision, Recall, F1:',precision_recall_fscore_support(y_true, y_pred, average='binary'))
    
    print('\n===Aggregate Stats===')
    p = np.concatenate(all_preds, axis = None)
    l = np.concatenate(all_labels, axis = None)
    print('Accuracy:', accuracy_score(l, p))
    print('Precision, Recall, F1:', precision_recall_fscore_support(l, p, average = 'binary'))

def test_model_titles(model, test_data, test_titles, test_labels):
    test_dataset = Dataset(range(len(test_labels)), test_data, test_labels, titles = test_titles)

    test_loader = torch_data.DataLoader(test_dataset, batch_size=len(test_labels))
    loss_fn = nn.BCELoss()

    predictions = None
    
    for context, t, label in test_loader:
        context = context.to(device)
        context = context.moveaxis(0, 1)
        t = t.to(device)
        t = t.moveaxis(0,1)
        label = label.to(device).type(torch.float32)

        #preds is a tensor of roughly torch.Size([305])
        preds = model.forward(context,t)
        predictions = preds

        loss = loss_fn(preds, label)

        print(loss.item())
    

    return loss.item(), (preds, test_labels)

kfold_crossvalidation_titles(train_data_array,train_title_array,train_labels,10,n_epochs = 40, model = None)

  "num_layers={}".format(dropout, num_layers))
  2%|▎         | 1/40 [00:00<00:25,  1.54it/s]

Loss: 7.006661653518677


  5%|▌         | 2/40 [00:01<00:24,  1.55it/s]

Loss: 6.4966921210289


  8%|▊         | 3/40 [00:01<00:23,  1.55it/s]

Loss: 6.378241539001465


 10%|█         | 4/40 [00:02<00:23,  1.55it/s]

Loss: 6.2170809507369995


 12%|█▎        | 5/40 [00:03<00:22,  1.55it/s]

Loss: 6.008680552244186


 15%|█▌        | 6/40 [00:03<00:21,  1.55it/s]

Loss: 5.741568148136139


 18%|█▊        | 7/40 [00:04<00:21,  1.56it/s]

Loss: 5.524868458509445


 20%|██        | 8/40 [00:05<00:20,  1.56it/s]

Loss: 5.258590847253799


 22%|██▎       | 9/40 [00:05<00:19,  1.56it/s]

Loss: 4.968223065137863


 25%|██▌       | 10/40 [00:06<00:19,  1.56it/s]

Loss: 4.750721424818039


 28%|██▊       | 11/40 [00:07<00:18,  1.56it/s]

Loss: 4.454420685768127


 30%|███       | 12/40 [00:07<00:17,  1.56it/s]

Loss: 4.3646558821201324


 32%|███▎      | 13/40 [00:08<00:17,  1.56it/s]

Loss: 3.685142919421196


 35%|███▌      | 14/40 [00:08<00:16,  1.56it/s]

Loss: 3.232719451189041


 38%|███▊      | 15/40 [00:09<00:15,  1.57it/s]

Loss: 2.9353972524404526


 40%|████      | 16/40 [00:10<00:15,  1.56it/s]

Loss: 2.635027453303337


 42%|████▎     | 17/40 [00:10<00:14,  1.56it/s]

Loss: 2.2702198326587677


 45%|████▌     | 18/40 [00:11<00:14,  1.56it/s]

Loss: 1.9983414113521576


 48%|████▊     | 19/40 [00:12<00:13,  1.56it/s]

Loss: 1.8451374173164368


 50%|█████     | 20/40 [00:12<00:12,  1.56it/s]

Loss: 1.7537621967494488


 52%|█████▎    | 21/40 [00:13<00:12,  1.56it/s]

Loss: 1.5638376250863075


 55%|█████▌    | 22/40 [00:14<00:11,  1.56it/s]

Loss: 1.5117861777544022


 57%|█████▊    | 23/40 [00:14<00:10,  1.56it/s]

Loss: 1.3834967017173767


 60%|██████    | 24/40 [00:15<00:10,  1.56it/s]

Loss: 1.3281264677643776


 62%|██████▎   | 25/40 [00:16<00:09,  1.55it/s]

Loss: 1.263189047574997


 65%|██████▌   | 26/40 [00:16<00:09,  1.55it/s]

Loss: 1.2168175764381886


 68%|██████▊   | 27/40 [00:17<00:08,  1.56it/s]

Loss: 1.1585466898977757


 70%|███████   | 28/40 [00:17<00:07,  1.56it/s]

Loss: 1.1316025704145432


 72%|███████▎  | 29/40 [00:18<00:07,  1.56it/s]

Loss: 1.1784398891031742


 75%|███████▌  | 30/40 [00:19<00:06,  1.56it/s]

Loss: 1.0970791056752205


 78%|███████▊  | 31/40 [00:19<00:05,  1.57it/s]

Loss: 1.0823239143937826


 80%|████████  | 32/40 [00:20<00:05,  1.57it/s]

Loss: 1.0645778700709343


 82%|████████▎ | 33/40 [00:21<00:04,  1.57it/s]

Loss: 1.0967147499322891


 85%|████████▌ | 34/40 [00:21<00:03,  1.56it/s]

Loss: 1.0862094387412071


 88%|████████▊ | 35/40 [00:22<00:03,  1.56it/s]

Loss: 1.0445865839719772


 90%|█████████ | 36/40 [00:23<00:02,  1.56it/s]

Loss: 1.0390084069222212


 92%|█████████▎| 37/40 [00:23<00:01,  1.56it/s]

Loss: 1.0323376432061195


 95%|█████████▌| 38/40 [00:24<00:01,  1.55it/s]

Loss: 1.034973356872797


 98%|█████████▊| 39/40 [00:25<00:00,  1.55it/s]

Loss: 1.0056269150227308


100%|██████████| 40/40 [00:25<00:00,  1.56it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.1024600062519312
[7.006661653518677, 6.4966921210289, 6.378241539001465, 6.2170809507369995, 6.008680552244186, 5.741568148136139, 5.524868458509445, 5.258590847253799, 4.968223065137863, 4.750721424818039, 4.454420685768127, 4.3646558821201324, 3.685142919421196, 3.232719451189041, 2.9353972524404526, 2.635027453303337, 2.2702198326587677, 1.9983414113521576, 1.8451374173164368, 1.7537621967494488, 1.5638376250863075, 1.5117861777544022, 1.3834967017173767, 1.3281264677643776, 1.263189047574997, 1.2168175764381886, 1.1585466898977757, 1.1316025704145432, 1.1784398891031742, 1.0970791056752205, 1.0823239143937826, 1.0645778700709343, 1.0967147499322891, 1.0862094387412071, 1.0445865839719772, 1.0390084069222212, 1.0323376432061195, 1.034973356872797, 1.0056269150227308, 1.1024600062519312]
1.1289854049682617
Accuracy: 0.7697368421052632
Precision, Recall, F1: (0.6551724137931034, 0.4318181818181818, 0.5205479452054794, None)


  2%|▎         | 1/40 [00:00<00:23,  1.63it/s]

Loss: 7.262626349925995


  5%|▌         | 2/40 [00:01<00:23,  1.62it/s]

Loss: 6.51101142168045


  8%|▊         | 3/40 [00:01<00:22,  1.61it/s]

Loss: 6.324973165988922


 10%|█         | 4/40 [00:02<00:22,  1.61it/s]

Loss: 6.164837598800659


 12%|█▎        | 5/40 [00:03<00:21,  1.61it/s]

Loss: 5.9258571565151215


 15%|█▌        | 6/40 [00:03<00:21,  1.61it/s]

Loss: 5.53159099817276


 18%|█▊        | 7/40 [00:04<00:20,  1.61it/s]

Loss: 5.404061734676361


 20%|██        | 8/40 [00:04<00:19,  1.61it/s]

Loss: 5.23636919260025


 22%|██▎       | 9/40 [00:05<00:19,  1.61it/s]

Loss: 4.91819441318512


 25%|██▌       | 10/40 [00:06<00:18,  1.61it/s]

Loss: 4.706940680742264


 28%|██▊       | 11/40 [00:06<00:18,  1.61it/s]

Loss: 4.424648433923721


 30%|███       | 12/40 [00:07<00:17,  1.61it/s]

Loss: 4.219170570373535


 32%|███▎      | 13/40 [00:08<00:16,  1.61it/s]

Loss: 3.87447053194046


 35%|███▌      | 14/40 [00:08<00:16,  1.61it/s]

Loss: 3.569759115576744


 38%|███▊      | 15/40 [00:09<00:15,  1.61it/s]

Loss: 3.2664838284254074


 40%|████      | 16/40 [00:09<00:14,  1.61it/s]

Loss: 2.7742929458618164


 42%|████▎     | 17/40 [00:10<00:14,  1.61it/s]

Loss: 2.459540292620659


 45%|████▌     | 18/40 [00:11<00:13,  1.61it/s]

Loss: 2.2125339657068253


 48%|████▊     | 19/40 [00:11<00:13,  1.62it/s]

Loss: 2.0274029225111008


 50%|█████     | 20/40 [00:12<00:12,  1.61it/s]

Loss: 1.8359735310077667


 52%|█████▎    | 21/40 [00:13<00:11,  1.61it/s]

Loss: 1.7231449484825134


 55%|█████▌    | 22/40 [00:13<00:11,  1.61it/s]

Loss: 1.6198662668466568


 57%|█████▊    | 23/40 [00:14<00:10,  1.61it/s]

Loss: 1.5800202190876007


 60%|██████    | 24/40 [00:14<00:09,  1.61it/s]

Loss: 1.4641465321183205


 62%|██████▎   | 25/40 [00:15<00:09,  1.61it/s]

Loss: 1.4393890090286732


 65%|██████▌   | 26/40 [00:16<00:08,  1.61it/s]

Loss: 1.4183744713664055


 68%|██████▊   | 27/40 [00:16<00:08,  1.61it/s]

Loss: 1.5033145025372505


 70%|███████   | 28/40 [00:17<00:07,  1.61it/s]

Loss: 1.2837935760617256


 72%|███████▎  | 29/40 [00:17<00:06,  1.61it/s]

Loss: 1.2746476791799068


 75%|███████▌  | 30/40 [00:18<00:06,  1.61it/s]

Loss: 1.275822963565588


 78%|███████▊  | 31/40 [00:19<00:05,  1.61it/s]

Loss: 1.2262422479689121


 80%|████████  | 32/40 [00:19<00:04,  1.61it/s]

Loss: 1.2176598906517029


 82%|████████▎ | 33/40 [00:20<00:04,  1.61it/s]

Loss: 1.2981316782534122


 85%|████████▌ | 34/40 [00:21<00:03,  1.61it/s]

Loss: 1.2453259695321321


 88%|████████▊ | 35/40 [00:21<00:03,  1.61it/s]

Loss: 1.179055567830801


 90%|█████████ | 36/40 [00:22<00:02,  1.61it/s]

Loss: 1.1363373324275017


 92%|█████████▎| 37/40 [00:22<00:01,  1.61it/s]

Loss: 1.212726190686226


 95%|█████████▌| 38/40 [00:23<00:01,  1.61it/s]

Loss: 1.0722387954592705


 98%|█████████▊| 39/40 [00:24<00:00,  1.61it/s]

Loss: 1.0382212437689304


100%|██████████| 40/40 [00:24<00:00,  1.61it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.048129478469491
[7.262626349925995, 6.51101142168045, 6.324973165988922, 6.164837598800659, 5.9258571565151215, 5.53159099817276, 5.404061734676361, 5.23636919260025, 4.91819441318512, 4.706940680742264, 4.424648433923721, 4.219170570373535, 3.87447053194046, 3.569759115576744, 3.2664838284254074, 2.7742929458618164, 2.459540292620659, 2.2125339657068253, 2.0274029225111008, 1.8359735310077667, 1.7231449484825134, 1.6198662668466568, 1.5800202190876007, 1.4641465321183205, 1.4393890090286732, 1.4183744713664055, 1.5033145025372505, 1.2837935760617256, 1.2746476791799068, 1.275822963565588, 1.2262422479689121, 1.2176598906517029, 1.2981316782534122, 1.2453259695321321, 1.179055567830801, 1.1363373324275017, 1.212726190686226, 1.0722387954592705, 1.0382212437689304, 1.048129478469491]
1.1541446447372437
Accuracy: 0.743421052631579
Precision, Recall, F1: (0.59375, 0.4222222222222222, 0.49350649350649345, None)


  2%|▎         | 1/40 [00:00<00:22,  1.73it/s]

Loss: 7.213360965251923


  5%|▌         | 2/40 [00:01<00:22,  1.71it/s]

Loss: 6.5478861927986145


  8%|▊         | 3/40 [00:01<00:21,  1.70it/s]

Loss: 6.380938708782196


 10%|█         | 4/40 [00:02<00:21,  1.70it/s]

Loss: 6.2023550271987915


 12%|█▎        | 5/40 [00:02<00:20,  1.69it/s]

Loss: 6.125727832317352


 15%|█▌        | 6/40 [00:03<00:19,  1.70it/s]

Loss: 5.880303204059601


 18%|█▊        | 7/40 [00:04<00:19,  1.69it/s]

Loss: 5.463090270757675


 20%|██        | 8/40 [00:04<00:18,  1.70it/s]

Loss: 5.242991507053375


 22%|██▎       | 9/40 [00:05<00:18,  1.70it/s]

Loss: 4.981876105070114


 25%|██▌       | 10/40 [00:05<00:17,  1.71it/s]

Loss: 4.767012417316437


 28%|██▊       | 11/40 [00:06<00:17,  1.70it/s]

Loss: 4.497995764017105


 30%|███       | 12/40 [00:07<00:16,  1.70it/s]

Loss: 4.160667568445206


 32%|███▎      | 13/40 [00:07<00:15,  1.70it/s]

Loss: 3.7858581840991974


 35%|███▌      | 14/40 [00:08<00:15,  1.70it/s]

Loss: 3.2803140580654144


 38%|███▊      | 15/40 [00:08<00:14,  1.70it/s]

Loss: 2.8988475501537323


 40%|████      | 16/40 [00:09<00:14,  1.70it/s]

Loss: 2.5680477619171143


 42%|████▎     | 17/40 [00:09<00:13,  1.70it/s]

Loss: 2.4233864098787308


 45%|████▌     | 18/40 [00:10<00:12,  1.70it/s]

Loss: 2.1195022016763687


 48%|████▊     | 19/40 [00:11<00:12,  1.70it/s]

Loss: 1.9224224984645844


 50%|█████     | 20/40 [00:11<00:11,  1.70it/s]

Loss: 1.7807263359427452


 52%|█████▎    | 21/40 [00:12<00:11,  1.70it/s]

Loss: 1.6307114213705063


 55%|█████▌    | 22/40 [00:12<00:10,  1.70it/s]

Loss: 1.708483088761568


 57%|█████▊    | 23/40 [00:13<00:10,  1.70it/s]

Loss: 1.511303961277008


 60%|██████    | 24/40 [00:14<00:09,  1.70it/s]

Loss: 1.4262320287525654


 62%|██████▎   | 25/40 [00:14<00:08,  1.70it/s]

Loss: 1.353358794003725


 65%|██████▌   | 26/40 [00:15<00:08,  1.70it/s]

Loss: 1.3214405290782452


 68%|██████▊   | 27/40 [00:15<00:07,  1.70it/s]

Loss: 1.3352267667651176


 70%|███████   | 28/40 [00:16<00:07,  1.70it/s]

Loss: 1.269742637872696


 72%|███████▎  | 29/40 [00:17<00:06,  1.70it/s]

Loss: 1.2287529110908508


 75%|███████▌  | 30/40 [00:17<00:05,  1.70it/s]

Loss: 1.2146244421601295


 78%|███████▊  | 31/40 [00:18<00:05,  1.70it/s]

Loss: 1.1915881074965


 80%|████████  | 32/40 [00:18<00:04,  1.70it/s]

Loss: 1.1581514775753021


 82%|████████▎ | 33/40 [00:19<00:04,  1.70it/s]

Loss: 1.177031997591257


 85%|████████▌ | 34/40 [00:20<00:03,  1.70it/s]

Loss: 1.135086752474308


 88%|████████▊ | 35/40 [00:20<00:02,  1.70it/s]

Loss: 1.1613508872687817


 90%|█████████ | 36/40 [00:21<00:02,  1.70it/s]

Loss: 1.1123925484716892


 92%|█████████▎| 37/40 [00:21<00:01,  1.70it/s]

Loss: 1.0937504805624485


 95%|█████████▌| 38/40 [00:22<00:01,  1.70it/s]

Loss: 1.0364702586084604


 98%|█████████▊| 39/40 [00:22<00:00,  1.70it/s]

Loss: 1.037808796390891


100%|██████████| 40/40 [00:23<00:00,  1.70it/s]

Loss: 1.0532821901142597
[7.213360965251923, 6.5478861927986145, 6.380938708782196, 6.2023550271987915, 6.125727832317352, 5.880303204059601, 5.463090270757675, 5.242991507053375, 4.981876105070114, 4.767012417316437, 4.497995764017105, 4.160667568445206, 3.7858581840991974, 3.2803140580654144, 2.8988475501537323, 2.5680477619171143, 2.4233864098787308, 2.1195022016763687, 1.9224224984645844, 1.7807263359427452, 1.6307114213705063, 1.708483088761568, 1.511303961277008, 1.4262320287525654, 1.353358794003725, 1.3214405290782452, 1.3352267667651176, 1.269742637872696, 1.2287529110908508, 1.2146244421601295, 1.1915881074965, 1.1581514775753021, 1.177031997591257, 1.135086752474308, 1.1613508872687817, 1.1123925484716892, 1.0937504805624485, 1.0364702586084604, 1.037808796390891, 1.0532821901142597]
1.5970202684402466
Accuracy: 0.8092105263157895
Precision, Recall, F1: (0.6, 0.5833333333333334, 0.591549295774648, None)



  "num_layers={}".format(dropout, num_layers))
  2%|▎         | 1/40 [00:00<00:23,  1.63it/s]

Loss: 7.276293337345123


  5%|▌         | 2/40 [00:01<00:23,  1.63it/s]

Loss: 6.48140025138855


  8%|▊         | 3/40 [00:01<00:22,  1.62it/s]

Loss: 6.339214026927948


 10%|█         | 4/40 [00:02<00:22,  1.62it/s]

Loss: 6.144001692533493


 12%|█▎        | 5/40 [00:03<00:21,  1.62it/s]

Loss: 5.93529537320137


 15%|█▌        | 6/40 [00:03<00:20,  1.62it/s]

Loss: 5.7003160417079926


 18%|█▊        | 7/40 [00:04<00:20,  1.62it/s]

Loss: 5.385420739650726


 20%|██        | 8/40 [00:04<00:19,  1.63it/s]

Loss: 5.172050356864929


 22%|██▎       | 9/40 [00:05<00:19,  1.63it/s]

Loss: 4.88698798418045


 25%|██▌       | 10/40 [00:06<00:18,  1.63it/s]

Loss: 4.589141458272934


 28%|██▊       | 11/40 [00:06<00:17,  1.63it/s]

Loss: 4.297521084547043


 30%|███       | 12/40 [00:07<00:17,  1.63it/s]

Loss: 3.8748326897621155


 32%|███▎      | 13/40 [00:07<00:16,  1.63it/s]

Loss: 3.6040701270103455


 35%|███▌      | 14/40 [00:08<00:15,  1.64it/s]

Loss: 3.1762405037879944


 38%|███▊      | 15/40 [00:09<00:15,  1.63it/s]

Loss: 2.7358577996492386


 40%|████      | 16/40 [00:09<00:14,  1.63it/s]

Loss: 2.4312907457351685


 42%|████▎     | 17/40 [00:10<00:14,  1.63it/s]

Loss: 2.2664507180452347


 45%|████▌     | 18/40 [00:11<00:13,  1.62it/s]

Loss: 2.1409720703959465


 48%|████▊     | 19/40 [00:11<00:12,  1.63it/s]

Loss: 1.9508605599403381


 50%|█████     | 20/40 [00:12<00:12,  1.63it/s]

Loss: 1.8819091692566872


 52%|█████▎    | 21/40 [00:12<00:11,  1.63it/s]

Loss: 1.759774312376976


 55%|█████▌    | 22/40 [00:13<00:11,  1.63it/s]

Loss: 1.542098082602024


 57%|█████▊    | 23/40 [00:14<00:10,  1.63it/s]

Loss: 1.5516751557588577


 60%|██████    | 24/40 [00:14<00:09,  1.63it/s]

Loss: 1.4584157951176167


 62%|██████▎   | 25/40 [00:15<00:09,  1.63it/s]

Loss: 1.466941386461258


 65%|██████▌   | 26/40 [00:15<00:08,  1.62it/s]

Loss: 1.3829194977879524


 68%|██████▊   | 27/40 [00:16<00:07,  1.63it/s]

Loss: 1.3302155956625938


 70%|███████   | 28/40 [00:17<00:07,  1.63it/s]

Loss: 1.2875746935606003


 72%|███████▎  | 29/40 [00:17<00:06,  1.63it/s]

Loss: 1.269763745367527


 75%|███████▌  | 30/40 [00:18<00:06,  1.63it/s]

Loss: 1.2353941723704338


 78%|███████▊  | 31/40 [00:19<00:05,  1.63it/s]

Loss: 1.2375153228640556


 80%|████████  | 32/40 [00:19<00:04,  1.63it/s]

Loss: 1.2177358269691467


 82%|████████▎ | 33/40 [00:20<00:04,  1.63it/s]

Loss: 1.207899548113346


 85%|████████▌ | 34/40 [00:20<00:03,  1.62it/s]

Loss: 1.1545945908874273


 88%|████████▊ | 35/40 [00:21<00:03,  1.62it/s]

Loss: 1.1472344510257244


 90%|█████████ | 36/40 [00:22<00:02,  1.62it/s]

Loss: 1.1394381001591682


 92%|█████████▎| 37/40 [00:22<00:01,  1.62it/s]

Loss: 1.1321498528122902


 95%|█████████▌| 38/40 [00:23<00:01,  1.62it/s]

Loss: 1.0951354131102562


 98%|█████████▊| 39/40 [00:23<00:00,  1.62it/s]

Loss: 1.093870848417282


100%|██████████| 40/40 [00:24<00:00,  1.63it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.0769252628087997
[7.276293337345123, 6.48140025138855, 6.339214026927948, 6.144001692533493, 5.93529537320137, 5.7003160417079926, 5.385420739650726, 5.172050356864929, 4.88698798418045, 4.589141458272934, 4.297521084547043, 3.8748326897621155, 3.6040701270103455, 3.1762405037879944, 2.7358577996492386, 2.4312907457351685, 2.2664507180452347, 2.1409720703959465, 1.9508605599403381, 1.8819091692566872, 1.759774312376976, 1.542098082602024, 1.5516751557588577, 1.4584157951176167, 1.466941386461258, 1.3829194977879524, 1.3302155956625938, 1.2875746935606003, 1.269763745367527, 1.2353941723704338, 1.2375153228640556, 1.2177358269691467, 1.207899548113346, 1.1545945908874273, 1.1472344510257244, 1.1394381001591682, 1.1321498528122902, 1.0951354131102562, 1.093870848417282, 1.0769252628087997]
1.2199856042861938
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.6071428571428571, 0.3541666666666667, 0.4473684210526316, None)


  2%|▎         | 1/40 [00:00<00:24,  1.59it/s]

Loss: 7.215121924877167


  5%|▌         | 2/40 [00:01<00:23,  1.58it/s]

Loss: 6.537576198577881


  8%|▊         | 3/40 [00:01<00:23,  1.58it/s]

Loss: 6.402123749256134


 10%|█         | 4/40 [00:02<00:22,  1.58it/s]

Loss: 6.208767890930176


 12%|█▎        | 5/40 [00:03<00:22,  1.59it/s]

Loss: 5.94431135058403


 15%|█▌        | 6/40 [00:03<00:21,  1.59it/s]

Loss: 5.659127444028854


 18%|█▊        | 7/40 [00:04<00:20,  1.60it/s]

Loss: 5.354546934366226


 20%|██        | 8/40 [00:05<00:20,  1.59it/s]

Loss: 5.128549873828888


 22%|██▎       | 9/40 [00:05<00:19,  1.60it/s]

Loss: 4.935650557279587


 25%|██▌       | 10/40 [00:06<00:18,  1.60it/s]

Loss: 4.689257383346558


 28%|██▊       | 11/40 [00:06<00:18,  1.60it/s]

Loss: 4.566861778497696


 30%|███       | 12/40 [00:07<00:17,  1.60it/s]

Loss: 4.29443496465683


 32%|███▎      | 13/40 [00:08<00:16,  1.60it/s]

Loss: 3.8347094655036926


 35%|███▌      | 14/40 [00:08<00:16,  1.60it/s]

Loss: 3.4129328429698944


 38%|███▊      | 15/40 [00:09<00:15,  1.60it/s]

Loss: 3.0276568681001663


 40%|████      | 16/40 [00:10<00:15,  1.59it/s]

Loss: 2.6384926587343216


 42%|████▎     | 17/40 [00:10<00:14,  1.60it/s]

Loss: 2.3059993535280228


 45%|████▌     | 18/40 [00:11<00:13,  1.60it/s]

Loss: 2.077969826757908


 48%|████▊     | 19/40 [00:11<00:13,  1.60it/s]

Loss: 1.9586023315787315


 50%|█████     | 20/40 [00:12<00:12,  1.60it/s]

Loss: 1.7289241626858711


 52%|█████▎    | 21/40 [00:13<00:11,  1.59it/s]

Loss: 1.5239313021302223


 55%|█████▌    | 22/40 [00:13<00:11,  1.60it/s]

Loss: 1.4406364969909191


 57%|█████▊    | 23/40 [00:14<00:10,  1.60it/s]

Loss: 1.3266240432858467


 60%|██████    | 24/40 [00:15<00:10,  1.60it/s]

Loss: 1.29245600476861


 62%|██████▎   | 25/40 [00:15<00:09,  1.59it/s]

Loss: 1.224849358201027


 65%|██████▌   | 26/40 [00:16<00:08,  1.59it/s]

Loss: 1.207428090274334


 68%|██████▊   | 27/40 [00:16<00:08,  1.59it/s]

Loss: 1.1895238272845745


 70%|███████   | 28/40 [00:17<00:07,  1.59it/s]

Loss: 1.1649752706289291


 72%|███████▎  | 29/40 [00:18<00:06,  1.59it/s]

Loss: 1.1156844813376665


 75%|███████▌  | 30/40 [00:18<00:06,  1.60it/s]

Loss: 1.129247473552823


 78%|███████▊  | 31/40 [00:19<00:05,  1.59it/s]

Loss: 1.091505702584982


 80%|████████  | 32/40 [00:20<00:05,  1.59it/s]

Loss: 1.039744008332491


 82%|████████▎ | 33/40 [00:20<00:04,  1.60it/s]

Loss: 0.9660644493997097


 85%|████████▌ | 34/40 [00:21<00:03,  1.60it/s]

Loss: 1.000022366642952


 88%|████████▊ | 35/40 [00:21<00:03,  1.60it/s]

Loss: 1.001721616834402


 90%|█████████ | 36/40 [00:22<00:02,  1.60it/s]

Loss: 1.2224850319325924


 92%|█████████▎| 37/40 [00:23<00:01,  1.60it/s]

Loss: 1.0342114083468914


 95%|█████████▌| 38/40 [00:23<00:01,  1.60it/s]

Loss: 1.2429446540772915


 98%|█████████▊| 39/40 [00:24<00:00,  1.60it/s]

Loss: 1.301542017608881


100%|██████████| 40/40 [00:25<00:00,  1.60it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.034809947013855
[7.215121924877167, 6.537576198577881, 6.402123749256134, 6.208767890930176, 5.94431135058403, 5.659127444028854, 5.354546934366226, 5.128549873828888, 4.935650557279587, 4.689257383346558, 4.566861778497696, 4.29443496465683, 3.8347094655036926, 3.4129328429698944, 3.0276568681001663, 2.6384926587343216, 2.3059993535280228, 2.077969826757908, 1.9586023315787315, 1.7289241626858711, 1.5239313021302223, 1.4406364969909191, 1.3266240432858467, 1.29245600476861, 1.224849358201027, 1.207428090274334, 1.1895238272845745, 1.1649752706289291, 1.1156844813376665, 1.129247473552823, 1.091505702584982, 1.039744008332491, 0.9660644493997097, 1.000022366642952, 1.001721616834402, 1.2224850319325924, 1.0342114083468914, 1.2429446540772915, 1.301542017608881, 1.034809947013855]
1.1285547018051147
Accuracy: 0.7302631578947368
Precision, Recall, F1: (0.525, 0.4883720930232558, 0.5060240963855422, None)


  2%|▎         | 1/40 [00:00<00:24,  1.61it/s]

Loss: 6.9864736795425415


  5%|▌         | 2/40 [00:01<00:23,  1.61it/s]

Loss: 6.616180062294006


  8%|▊         | 3/40 [00:01<00:22,  1.61it/s]

Loss: 6.4628554582595825


 10%|█         | 4/40 [00:02<00:22,  1.62it/s]

Loss: 6.365916609764099


 12%|█▎        | 5/40 [00:03<00:21,  1.62it/s]

Loss: 6.147578984498978


 15%|█▌        | 6/40 [00:03<00:20,  1.62it/s]

Loss: 5.940781116485596


 18%|█▊        | 7/40 [00:04<00:20,  1.62it/s]

Loss: 5.6576985120773315


 20%|██        | 8/40 [00:04<00:19,  1.62it/s]

Loss: 5.507376194000244


 22%|██▎       | 9/40 [00:05<00:19,  1.63it/s]

Loss: 5.279149204492569


 25%|██▌       | 10/40 [00:06<00:18,  1.62it/s]

Loss: 5.007588803768158


 28%|██▊       | 11/40 [00:06<00:17,  1.62it/s]

Loss: 4.748581022024155


 30%|███       | 12/40 [00:07<00:17,  1.63it/s]

Loss: 4.387487322092056


 32%|███▎      | 13/40 [00:08<00:16,  1.62it/s]

Loss: 3.901605322957039


 35%|███▌      | 14/40 [00:08<00:16,  1.62it/s]

Loss: 3.624173626303673


 38%|███▊      | 15/40 [00:09<00:15,  1.62it/s]

Loss: 3.191980004310608


 40%|████      | 16/40 [00:09<00:14,  1.62it/s]

Loss: 2.923292562365532


 42%|████▎     | 17/40 [00:10<00:14,  1.62it/s]

Loss: 2.8649659156799316


 45%|████▌     | 18/40 [00:11<00:13,  1.62it/s]

Loss: 2.466131716966629


 48%|████▊     | 19/40 [00:11<00:12,  1.62it/s]

Loss: 2.2168161123991013


 50%|█████     | 20/40 [00:12<00:12,  1.62it/s]

Loss: 2.0732851326465607


 52%|█████▎    | 21/40 [00:12<00:11,  1.62it/s]

Loss: 1.9450483918190002


 55%|█████▌    | 22/40 [00:13<00:11,  1.62it/s]

Loss: 1.8918826282024384


 57%|█████▊    | 23/40 [00:14<00:10,  1.62it/s]

Loss: 1.7533039227128029


 60%|██████    | 24/40 [00:14<00:09,  1.62it/s]

Loss: 1.6366291493177414


 62%|██████▎   | 25/40 [00:15<00:09,  1.62it/s]

Loss: 1.5559233650565147


 65%|██████▌   | 26/40 [00:16<00:08,  1.62it/s]

Loss: 1.5339264422655106


 68%|██████▊   | 27/40 [00:16<00:08,  1.62it/s]

Loss: 1.4834337681531906


 70%|███████   | 28/40 [00:17<00:07,  1.62it/s]

Loss: 1.4006017670035362


 72%|███████▎  | 29/40 [00:17<00:06,  1.62it/s]

Loss: 1.4646856561303139


 75%|███████▌  | 30/40 [00:18<00:06,  1.62it/s]

Loss: 1.3619816713035107


 78%|███████▊  | 31/40 [00:19<00:05,  1.62it/s]

Loss: 1.4012693166732788


 80%|████████  | 32/40 [00:19<00:04,  1.62it/s]

Loss: 1.3672458305954933


 82%|████████▎ | 33/40 [00:20<00:04,  1.62it/s]

Loss: 1.3227479681372643


 85%|████████▌ | 34/40 [00:20<00:03,  1.62it/s]

Loss: 1.300270564854145


 88%|████████▊ | 35/40 [00:21<00:03,  1.62it/s]

Loss: 1.2587889209389687


 90%|█████████ | 36/40 [00:22<00:02,  1.62it/s]

Loss: 1.2208088636398315


 92%|█████████▎| 37/40 [00:22<00:01,  1.62it/s]

Loss: 1.19971589371562


 95%|█████████▌| 38/40 [00:23<00:01,  1.62it/s]

Loss: 1.1836632043123245


 98%|█████████▊| 39/40 [00:24<00:00,  1.62it/s]

Loss: 1.1623235121369362


100%|██████████| 40/40 [00:24<00:00,  1.62it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.1567716151475906
[6.9864736795425415, 6.616180062294006, 6.4628554582595825, 6.365916609764099, 6.147578984498978, 5.940781116485596, 5.6576985120773315, 5.507376194000244, 5.279149204492569, 5.007588803768158, 4.748581022024155, 4.387487322092056, 3.901605322957039, 3.624173626303673, 3.191980004310608, 2.923292562365532, 2.8649659156799316, 2.466131716966629, 2.2168161123991013, 2.0732851326465607, 1.9450483918190002, 1.8918826282024384, 1.7533039227128029, 1.6366291493177414, 1.5559233650565147, 1.5339264422655106, 1.4834337681531906, 1.4006017670035362, 1.4646856561303139, 1.3619816713035107, 1.4012693166732788, 1.3672458305954933, 1.3227479681372643, 1.300270564854145, 1.2587889209389687, 1.2208088636398315, 1.19971589371562, 1.1836632043123245, 1.1623235121369362, 1.1567716151475906]
0.8912470936775208
Accuracy: 0.8223684210526315
Precision, Recall, F1: (0.6428571428571429, 0.6923076923076923, 0.6666666666666666, None)


  2%|▎         | 1/40 [00:00<00:23,  1.64it/s]

Loss: 6.8659244775772095


  5%|▌         | 2/40 [00:01<00:23,  1.65it/s]

Loss: 6.4419591426849365


  8%|▊         | 3/40 [00:01<00:22,  1.65it/s]

Loss: 6.237116754055023


 10%|█         | 4/40 [00:02<00:21,  1.65it/s]

Loss: 6.217990398406982


 12%|█▎        | 5/40 [00:03<00:21,  1.66it/s]

Loss: 6.065732955932617


 15%|█▌        | 6/40 [00:03<00:20,  1.66it/s]

Loss: 5.870999902486801


 18%|█▊        | 7/40 [00:04<00:19,  1.65it/s]

Loss: 5.515024870634079


 20%|██        | 8/40 [00:04<00:19,  1.65it/s]

Loss: 5.27277284860611


 22%|██▎       | 9/40 [00:05<00:18,  1.65it/s]

Loss: 5.057936131954193


 25%|██▌       | 10/40 [00:06<00:18,  1.65it/s]

Loss: 4.92423751950264


 28%|██▊       | 11/40 [00:06<00:17,  1.64it/s]

Loss: 4.668417304754257


 30%|███       | 12/40 [00:07<00:17,  1.64it/s]

Loss: 4.375585973262787


 32%|███▎      | 13/40 [00:07<00:16,  1.64it/s]

Loss: 4.110835164785385


 35%|███▌      | 14/40 [00:08<00:15,  1.65it/s]

Loss: 3.808676689863205


 38%|███▊      | 15/40 [00:09<00:15,  1.65it/s]

Loss: 3.339978978037834


 40%|████      | 16/40 [00:09<00:14,  1.65it/s]

Loss: 3.142683506011963


 42%|████▎     | 17/40 [00:10<00:13,  1.65it/s]

Loss: 2.741003766655922


 45%|████▌     | 18/40 [00:10<00:13,  1.66it/s]

Loss: 2.3798591643571854


 48%|████▊     | 19/40 [00:11<00:12,  1.66it/s]

Loss: 2.294387623667717


 50%|█████     | 20/40 [00:12<00:12,  1.65it/s]

Loss: 2.1827442348003387


 52%|█████▎    | 21/40 [00:12<00:11,  1.65it/s]

Loss: 2.0709895938634872


 55%|█████▌    | 22/40 [00:13<00:10,  1.65it/s]

Loss: 2.101112872362137


 57%|█████▊    | 23/40 [00:13<00:10,  1.65it/s]

Loss: 1.8501005172729492


 60%|██████    | 24/40 [00:14<00:09,  1.65it/s]

Loss: 1.584350861608982


 62%|██████▎   | 25/40 [00:15<00:09,  1.66it/s]

Loss: 1.5161267220973969


 65%|██████▌   | 26/40 [00:15<00:08,  1.65it/s]

Loss: 1.4453900530934334


 68%|██████▊   | 27/40 [00:16<00:07,  1.65it/s]

Loss: 1.3697500452399254


 70%|███████   | 28/40 [00:16<00:07,  1.65it/s]

Loss: 1.3043201640248299


 72%|███████▎  | 29/40 [00:17<00:06,  1.65it/s]

Loss: 1.2668139189481735


 75%|███████▌  | 30/40 [00:18<00:06,  1.65it/s]

Loss: 1.2105571031570435


 78%|███████▊  | 31/40 [00:18<00:05,  1.65it/s]

Loss: 1.2452741228044033


 80%|████████  | 32/40 [00:19<00:04,  1.65it/s]

Loss: 1.2504811882972717


 82%|████████▎ | 33/40 [00:19<00:04,  1.65it/s]

Loss: 1.2821625173091888


 85%|████████▌ | 34/40 [00:20<00:03,  1.65it/s]

Loss: 1.2146298736333847


 88%|████████▊ | 35/40 [00:21<00:03,  1.65it/s]

Loss: 1.2475780621170998


 90%|█████████ | 36/40 [00:21<00:02,  1.66it/s]

Loss: 1.2469500415027142


 92%|█████████▎| 37/40 [00:22<00:01,  1.66it/s]

Loss: 1.1595060210675001


 95%|█████████▌| 38/40 [00:23<00:01,  1.66it/s]

Loss: 1.1173616889864206


 98%|█████████▊| 39/40 [00:23<00:00,  1.66it/s]

Loss: 1.1402573138475418


100%|██████████| 40/40 [00:24<00:00,  1.65it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.1298575438559055
[6.8659244775772095, 6.4419591426849365, 6.237116754055023, 6.217990398406982, 6.065732955932617, 5.870999902486801, 5.515024870634079, 5.27277284860611, 5.057936131954193, 4.92423751950264, 4.668417304754257, 4.375585973262787, 4.110835164785385, 3.808676689863205, 3.339978978037834, 3.142683506011963, 2.741003766655922, 2.3798591643571854, 2.294387623667717, 2.1827442348003387, 2.0709895938634872, 2.101112872362137, 1.8501005172729492, 1.584350861608982, 1.5161267220973969, 1.4453900530934334, 1.3697500452399254, 1.3043201640248299, 1.2668139189481735, 1.2105571031570435, 1.2452741228044033, 1.2504811882972717, 1.2821625173091888, 1.2146298736333847, 1.2475780621170998, 1.2469500415027142, 1.1595060210675001, 1.1173616889864206, 1.1402573138475418, 1.1298575438559055]
1.7341598272323608
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.6388888888888888, 0.4423076923076923, 0.5227272727272726, None)


  2%|▎         | 1/40 [00:00<00:24,  1.58it/s]

Loss: 7.046912848949432


  5%|▌         | 2/40 [00:01<00:24,  1.57it/s]

Loss: 6.544098377227783


  8%|▊         | 3/40 [00:01<00:23,  1.57it/s]

Loss: 6.3838348388671875


 10%|█         | 4/40 [00:02<00:22,  1.57it/s]

Loss: 6.2269163727760315


 12%|█▎        | 5/40 [00:03<00:22,  1.56it/s]

Loss: 6.037216603755951


 15%|█▌        | 6/40 [00:03<00:21,  1.56it/s]

Loss: 5.651845097541809


 18%|█▊        | 7/40 [00:04<00:21,  1.57it/s]

Loss: 5.319203227758408


 20%|██        | 8/40 [00:05<00:20,  1.57it/s]

Loss: 5.082265019416809


 22%|██▎       | 9/40 [00:05<00:19,  1.57it/s]

Loss: 4.989072561264038


 25%|██▌       | 10/40 [00:06<00:19,  1.58it/s]

Loss: 4.760319024324417


 28%|██▊       | 11/40 [00:07<00:18,  1.57it/s]

Loss: 4.4243727922439575


 30%|███       | 12/40 [00:07<00:17,  1.57it/s]

Loss: 4.14660257101059


 32%|███▎      | 13/40 [00:08<00:17,  1.57it/s]

Loss: 3.5735769271850586


 35%|███▌      | 14/40 [00:08<00:16,  1.57it/s]

Loss: 3.1570314168930054


 38%|███▊      | 15/40 [00:09<00:15,  1.58it/s]

Loss: 2.7263012677431107


 40%|████      | 16/40 [00:10<00:15,  1.58it/s]

Loss: 2.332012951374054


 42%|████▎     | 17/40 [00:10<00:14,  1.57it/s]

Loss: 2.1372210681438446


 45%|████▌     | 18/40 [00:11<00:13,  1.57it/s]

Loss: 1.9302217066287994


 48%|████▊     | 19/40 [00:12<00:13,  1.57it/s]

Loss: 1.8286557123064995


 50%|█████     | 20/40 [00:12<00:12,  1.57it/s]

Loss: 1.6503037437796593


 52%|█████▎    | 21/40 [00:13<00:12,  1.57it/s]

Loss: 1.4832249097526073


 55%|█████▌    | 22/40 [00:14<00:11,  1.57it/s]

Loss: 1.3822984248399734


 57%|█████▊    | 23/40 [00:14<00:10,  1.57it/s]

Loss: 1.2781204581260681


 60%|██████    | 24/40 [00:15<00:10,  1.57it/s]

Loss: 1.2108925133943558


 62%|██████▎   | 25/40 [00:15<00:09,  1.56it/s]

Loss: 1.2485051304101944


 65%|██████▌   | 26/40 [00:16<00:08,  1.57it/s]

Loss: 1.172810323536396


 68%|██████▊   | 27/40 [00:17<00:08,  1.58it/s]

Loss: 1.135332863777876


 70%|███████   | 28/40 [00:17<00:07,  1.58it/s]

Loss: 1.2030004691332579


 72%|███████▎  | 29/40 [00:18<00:06,  1.58it/s]

Loss: 1.108421940356493


 75%|███████▌  | 30/40 [00:19<00:06,  1.58it/s]

Loss: 1.0829537138342857


 78%|███████▊  | 31/40 [00:19<00:05,  1.57it/s]

Loss: 1.0175336115062237


 80%|████████  | 32/40 [00:20<00:05,  1.58it/s]

Loss: 0.9862884134054184


 82%|████████▎ | 33/40 [00:20<00:04,  1.57it/s]

Loss: 0.9779583923518658


 85%|████████▌ | 34/40 [00:21<00:03,  1.57it/s]

Loss: 0.9593494646251202


 88%|████████▊ | 35/40 [00:22<00:03,  1.57it/s]

Loss: 0.967224245890975


 90%|█████████ | 36/40 [00:22<00:02,  1.58it/s]

Loss: 0.8950511440634727


 92%|█████████▎| 37/40 [00:23<00:01,  1.58it/s]

Loss: 0.9019747450947762


 95%|█████████▌| 38/40 [00:24<00:01,  1.57it/s]

Loss: 0.8733561113476753


 98%|█████████▊| 39/40 [00:24<00:00,  1.57it/s]

Loss: 0.8709127716720104


100%|██████████| 40/40 [00:25<00:00,  1.57it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.1053336970508099
[7.046912848949432, 6.544098377227783, 6.3838348388671875, 6.2269163727760315, 6.037216603755951, 5.651845097541809, 5.319203227758408, 5.082265019416809, 4.989072561264038, 4.760319024324417, 4.4243727922439575, 4.14660257101059, 3.5735769271850586, 3.1570314168930054, 2.7263012677431107, 2.332012951374054, 2.1372210681438446, 1.9302217066287994, 1.8286557123064995, 1.6503037437796593, 1.4832249097526073, 1.3822984248399734, 1.2781204581260681, 1.2108925133943558, 1.2485051304101944, 1.172810323536396, 1.135332863777876, 1.2030004691332579, 1.108421940356493, 1.0829537138342857, 1.0175336115062237, 0.9862884134054184, 0.9779583923518658, 0.9593494646251202, 0.967224245890975, 0.8950511440634727, 0.9019747450947762, 0.8733561113476753, 0.8709127716720104, 1.1053336970508099]
1.3011761903762817
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.38095238095238093, 0.21621621621621623, 0.27586206896551724, None)


  2%|▎         | 1/40 [00:00<00:24,  1.60it/s]

Loss: 7.197408378124237


  5%|▌         | 2/40 [00:01<00:23,  1.60it/s]

Loss: 6.583252012729645


  8%|▊         | 3/40 [00:01<00:23,  1.60it/s]

Loss: 6.427386522293091


 10%|█         | 4/40 [00:02<00:22,  1.60it/s]

Loss: 6.262961506843567


 12%|█▎        | 5/40 [00:03<00:21,  1.60it/s]

Loss: 5.94216114282608


 15%|█▌        | 6/40 [00:03<00:21,  1.60it/s]

Loss: 5.541614472866058


 18%|█▊        | 7/40 [00:04<00:20,  1.61it/s]

Loss: 5.286610275506973


 20%|██        | 8/40 [00:04<00:19,  1.60it/s]

Loss: 5.107948929071426


 22%|██▎       | 9/40 [00:05<00:19,  1.60it/s]

Loss: 4.917902588844299


 25%|██▌       | 10/40 [00:06<00:18,  1.60it/s]

Loss: 4.756840378046036


 28%|██▊       | 11/40 [00:06<00:18,  1.59it/s]

Loss: 4.490712910890579


 30%|███       | 12/40 [00:07<00:17,  1.60it/s]

Loss: 4.2222690880298615


 32%|███▎      | 13/40 [00:08<00:16,  1.60it/s]

Loss: 3.8480521142482758


 35%|███▌      | 14/40 [00:08<00:16,  1.60it/s]

Loss: 3.436291068792343


 38%|███▊      | 15/40 [00:09<00:15,  1.60it/s]

Loss: 3.1411882638931274


 40%|████      | 16/40 [00:09<00:14,  1.61it/s]

Loss: 2.805819347500801


 42%|████▎     | 17/40 [00:10<00:14,  1.60it/s]

Loss: 2.454251065850258


 45%|████▌     | 18/40 [00:11<00:13,  1.60it/s]

Loss: 2.142840698361397


 48%|████▊     | 19/40 [00:11<00:13,  1.60it/s]

Loss: 1.955758050084114


 50%|█████     | 20/40 [00:12<00:12,  1.60it/s]

Loss: 1.941683679819107


 52%|█████▎    | 21/40 [00:13<00:11,  1.60it/s]

Loss: 1.9002515375614166


 55%|█████▌    | 22/40 [00:13<00:11,  1.60it/s]

Loss: 1.7820546627044678


 57%|█████▊    | 23/40 [00:14<00:10,  1.60it/s]

Loss: 1.6159545630216599


 60%|██████    | 24/40 [00:15<00:09,  1.60it/s]

Loss: 1.5799636915326118


 62%|██████▎   | 25/40 [00:15<00:09,  1.60it/s]

Loss: 1.4991140812635422


 65%|██████▌   | 26/40 [00:16<00:08,  1.60it/s]

Loss: 1.5488112792372704


 68%|██████▊   | 27/40 [00:16<00:08,  1.60it/s]

Loss: 1.5270564630627632


 70%|███████   | 28/40 [00:17<00:07,  1.60it/s]

Loss: 1.4974652826786041


 72%|███████▎  | 29/40 [00:18<00:06,  1.60it/s]

Loss: 1.4815292656421661


 75%|███████▌  | 30/40 [00:18<00:06,  1.59it/s]

Loss: 1.600169599056244


 78%|███████▊  | 31/40 [00:19<00:05,  1.59it/s]

Loss: 1.5619353353977203


 80%|████████  | 32/40 [00:20<00:05,  1.59it/s]

Loss: 1.6001002192497253


 82%|████████▎ | 33/40 [00:20<00:04,  1.59it/s]

Loss: 1.535010363906622


 85%|████████▌ | 34/40 [00:21<00:03,  1.59it/s]

Loss: 1.4323739223182201


 88%|████████▊ | 35/40 [00:21<00:03,  1.59it/s]

Loss: 1.3853686228394508


 90%|█████████ | 36/40 [00:22<00:02,  1.60it/s]

Loss: 1.3676151558756828


 92%|█████████▎| 37/40 [00:23<00:01,  1.59it/s]

Loss: 1.3933877125382423


 95%|█████████▌| 38/40 [00:23<00:01,  1.59it/s]

Loss: 1.3824001923203468


 98%|█████████▊| 39/40 [00:24<00:00,  1.59it/s]

Loss: 1.3072790764272213


100%|██████████| 40/40 [00:25<00:00,  1.60it/s]
  "num_layers={}".format(dropout, num_layers))


Loss: 1.3202922195196152
[7.197408378124237, 6.583252012729645, 6.427386522293091, 6.262961506843567, 5.94216114282608, 5.541614472866058, 5.286610275506973, 5.107948929071426, 4.917902588844299, 4.756840378046036, 4.490712910890579, 4.2222690880298615, 3.8480521142482758, 3.436291068792343, 3.1411882638931274, 2.805819347500801, 2.454251065850258, 2.142840698361397, 1.955758050084114, 1.941683679819107, 1.9002515375614166, 1.7820546627044678, 1.6159545630216599, 1.5799636915326118, 1.4991140812635422, 1.5488112792372704, 1.5270564630627632, 1.4974652826786041, 1.4815292656421661, 1.600169599056244, 1.5619353353977203, 1.6001002192497253, 1.535010363906622, 1.4323739223182201, 1.3853686228394508, 1.3676151558756828, 1.3933877125382423, 1.3824001923203468, 1.3072790764272213, 1.3202922195196152]
1.1189643144607544
Accuracy: 0.7236842105263158
Precision, Recall, F1: (0.47368421052631576, 0.45, 0.46153846153846156, None)


  2%|▎         | 1/40 [00:00<00:23,  1.63it/s]

Loss: 7.007881164550781


  5%|▌         | 2/40 [00:01<00:23,  1.62it/s]

Loss: 6.514572739601135


  8%|▊         | 3/40 [00:01<00:22,  1.63it/s]

Loss: 6.295890629291534


 10%|█         | 4/40 [00:02<00:22,  1.63it/s]

Loss: 6.18398243188858


 12%|█▎        | 5/40 [00:03<00:21,  1.63it/s]

Loss: 6.002837538719177


 15%|█▌        | 6/40 [00:03<00:20,  1.63it/s]

Loss: 5.620742708444595


 18%|█▊        | 7/40 [00:04<00:20,  1.63it/s]

Loss: 5.33841010928154


 20%|██        | 8/40 [00:04<00:19,  1.63it/s]

Loss: 5.192352503538132


 22%|██▎       | 9/40 [00:05<00:19,  1.63it/s]

Loss: 4.962832897901535


 25%|██▌       | 10/40 [00:06<00:18,  1.63it/s]

Loss: 4.69882208108902


 28%|██▊       | 11/40 [00:06<00:17,  1.63it/s]

Loss: 4.4200833439826965


 30%|███       | 12/40 [00:07<00:17,  1.62it/s]

Loss: 3.935879409313202


 32%|███▎      | 13/40 [00:07<00:16,  1.62it/s]

Loss: 3.6734746396541595


 35%|███▌      | 14/40 [00:08<00:15,  1.63it/s]

Loss: 3.442569464445114


 38%|███▊      | 15/40 [00:09<00:15,  1.62it/s]

Loss: 3.0133329033851624


 40%|████      | 16/40 [00:09<00:14,  1.62it/s]

Loss: 2.601894363760948


 42%|████▎     | 17/40 [00:10<00:14,  1.63it/s]

Loss: 2.4236233830451965


 45%|████▌     | 18/40 [00:11<00:13,  1.62it/s]

Loss: 2.1495791003108025


 48%|████▊     | 19/40 [00:11<00:12,  1.63it/s]

Loss: 1.9845887273550034


 50%|█████     | 20/40 [00:12<00:12,  1.63it/s]

Loss: 1.8366825729608536


 52%|█████▎    | 21/40 [00:12<00:11,  1.62it/s]

Loss: 1.7102325335144997


 55%|█████▌    | 22/40 [00:13<00:11,  1.62it/s]

Loss: 1.8536003679037094


 57%|█████▊    | 23/40 [00:14<00:10,  1.62it/s]

Loss: 1.6579052060842514


 60%|██████    | 24/40 [00:14<00:09,  1.62it/s]

Loss: 1.522675298154354


 62%|██████▎   | 25/40 [00:15<00:09,  1.62it/s]

Loss: 1.3955547958612442


 65%|██████▌   | 26/40 [00:16<00:08,  1.62it/s]

Loss: 1.3329942002892494


 68%|██████▊   | 27/40 [00:16<00:08,  1.62it/s]

Loss: 1.2835547365248203


 70%|███████   | 28/40 [00:17<00:07,  1.62it/s]

Loss: 1.2127387151122093


 72%|███████▎  | 29/40 [00:17<00:06,  1.62it/s]

Loss: 1.2782229594886303


 75%|███████▌  | 30/40 [00:18<00:06,  1.62it/s]

Loss: 1.185468541458249


 78%|███████▊  | 31/40 [00:19<00:05,  1.62it/s]

Loss: 1.116219513118267


 80%|████████  | 32/40 [00:19<00:04,  1.63it/s]

Loss: 1.0587582141160965


 82%|████████▎ | 33/40 [00:20<00:04,  1.62it/s]

Loss: 1.033531367778778


 85%|████████▌ | 34/40 [00:20<00:03,  1.62it/s]

Loss: 1.041354551911354


 88%|████████▊ | 35/40 [00:21<00:03,  1.62it/s]

Loss: 1.041616465896368


 90%|█████████ | 36/40 [00:22<00:02,  1.62it/s]

Loss: 0.9723757393658161


 92%|█████████▎| 37/40 [00:22<00:01,  1.63it/s]

Loss: 0.9816540777683258


 95%|█████████▌| 38/40 [00:23<00:01,  1.63it/s]

Loss: 0.9321091827005148


 98%|█████████▊| 39/40 [00:24<00:00,  1.63it/s]

Loss: 0.9157585259526968


100%|██████████| 40/40 [00:24<00:00,  1.62it/s]

Loss: 0.900755463168025
[7.007881164550781, 6.514572739601135, 6.295890629291534, 6.18398243188858, 6.002837538719177, 5.620742708444595, 5.33841010928154, 5.192352503538132, 4.962832897901535, 4.69882208108902, 4.4200833439826965, 3.935879409313202, 3.6734746396541595, 3.442569464445114, 3.0133329033851624, 2.601894363760948, 2.4236233830451965, 2.1495791003108025, 1.9845887273550034, 1.8366825729608536, 1.7102325335144997, 1.8536003679037094, 1.6579052060842514, 1.522675298154354, 1.3955547958612442, 1.3329942002892494, 1.2835547365248203, 1.2127387151122093, 1.2782229594886303, 1.185468541458249, 1.116219513118267, 1.0587582141160965, 1.033531367778778, 1.041354551911354, 1.041616465896368, 0.9723757393658161, 0.9816540777683258, 0.9321091827005148, 0.9157585259526968, 0.900755463168025]
2.7549076080322266
Accuracy: 0.6776315789473685
Precision, Recall, F1: (0.4888888888888889, 0.4583333333333333, 0.4731182795698925, None)

===Aggregate Stats===
Accuracy: 0.7447368421052631
Precisio




In [None]:
trained_model = train(train_data_array, train_labels, 20, 128, modeltype = FullModel)
_, (p, l) = test_model(trained_model, test_data_array, test_labels)
p = torch.round(p).cpu().detach().numpy()
print('Accuracy:',accuracy_score(l,p))
print('Precision, Recall, F1:',precision_recall_fscore_support(l, p, average = 'binary'))

In [None]:
trained_model = train_with_titles(train_data_array, train_title_array, train_labels, 20, 128)
_, (p, l) = test_model_titles(trained_model, test_data_array, test_title_array, test_labels)
p = torch.round(p).cpu().detach().numpy()
print('Accuracy:',accuracy_score(l,p))
print('Precision, Recall, F1:',precision_recall_fscore_support(l, p, average = 'binary'))