# 0. Environment setup

In [None]:
# Library to read json
import json

# Numeric and data manipulation tools
import pandas as pd
import numpy as np
import random

# Deep learning framework
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Natural language tools
import nltk
import gensim
import gensim.downloader as gloader

# Other tools
from tqdm.notebook import tqdm
from collections import OrderedDict
from time import time
from itertools import zip_longest

# Abstract classes
from abc import ABC, abstractmethod

# automatic mixed precision training:
from torch.cuda.amp import autocast 
from torch.cuda.amp import GradScaler

# Type hint
from typing import Optional, Callable, Tuple, Dict, List, Union

# Fetch the NLTK resources used later in the notebook
for nltk_resource in ('punkt', 'stopwords'):
    nltk.download(nltk_resource)

# from sklearn.model_selection import train_test_split

# Pick the first CUDA device when one is present, otherwise fall back to the CPU
cuda_is_available = torch.cuda.is_available()
DEVICE = torch.device('cuda:0' if cuda_is_available else 'cpu')
print("using this device:", DEVICE)

# Abort early on CPU-only runtimes
if not cuda_is_available:
    raise Exception("switch to runtime GPU, otherwise the code won't work properly")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
using this device: cuda:0


In [None]:
# Set seed for reproducibility
def fix_random(seed: int) -> None:
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and switches cuDNN to its deterministic mode.

    Args:
        seed: Value used to seed all generators.
    """
    # Python's own generator is seeded too, since `random` is imported and used in this file
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable
    torch.cuda.manual_seed_all(seed)
    # Disable the cuDNN auto-tuner and request deterministic kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

fix_random(42)

In [None]:
# Run on the first CUDA device when available, on the CPU otherwise
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("using this device:", DEVICE)

# Token used to pad sequences to a common length
PAD = '<PAD>'

using this device: cuda:0


In [None]:
def mean(l: List[float]) -> float:
    """Return the arithmetic mean of a non-empty list.

    Raises:
        ZeroDivisionError: If ``l`` is empty.
    """
    return sum(l) / len(l)


def to_tuple_of_lists(list_of_tuples: List[Tuple]) -> Tuple[List, ...]:
    """Transform a list of tuples into a tuple of lists (one list per tuple position)."""
    return tuple(map(list, zip(*list_of_tuples)))


def to_list_of_tuples(tuple_of_lists: Tuple[List, ...]) -> List[Tuple]:
    """Transform a tuple of lists into a list of tuples (one tuple per position)."""
    return list(zip(*tuple_of_lists))


def batch_iteration(data: List[Tuple], batch_size: int) -> zip_longest:
    """Iterate over ``data`` in consecutive groups of ``batch_size`` items.

    When the number of items is not a multiple of ``batch_size`` the last
    group is filled up with the placeholder sample ``([], [], '')``.
    """
    # The same iterator repeated `batch_size` times makes zip_longest pull
    # `batch_size` consecutive items for every emitted group
    return zip_longest(*[iter(data)] * batch_size, fillvalue=([], [], ''))

# 1. Dataset preparation

In [None]:
""" OLD CODE
filename = 'training_set.json'

with open(filename, 'r') as f:
    raw_data = f.readlines()[0]

parsed_data = json.loads(raw_data)['data']

context_list = []
context_index = -1
title_list = []
title_index = -1

dataset = {'title_index': [], 'context_index': [], 'question': [], 'answer_start': [], 'answer_end': [], 'answer_text': []}

for i in range(len(parsed_data)):
    title_list.append(parsed_data[i]['title'])
    title_index += 1
    for j in range(len(parsed_data[i]['paragraphs'])):
        context_list.append(parsed_data[i]['paragraphs'][j]['context'])
        context_index += 1

        for k in range(len(parsed_data[i]['paragraphs'][j]['qas'])):
            question = parsed_data[i]['paragraphs'][j]['qas'][k]['question']

            for l in range(len(parsed_data[i]['paragraphs'][j]['qas'][k]['answers'])): 
                answer_start = parsed_data[i]['paragraphs'][j]['qas'][k]['answers'][l]['answer_start']
                answer_text = parsed_data[i]['paragraphs'][j]['qas'][k]['answers'][l]['text']

                answer_end = answer_start + len(answer_text)

                dataset['title_index'].append(title_index)
                dataset['context_index'].append(context_index)
                dataset['question'].append(question)
                dataset['answer_start'].append(answer_start)
                dataset['answer_end'].append(answer_end)
                dataset['answer_text'].append(answer_text)

df = pd.DataFrame.from_dict(dataset)

df.head()
------------------------------------------------------------------------------------------------------------
"""

"""
json structure:

data []
|---title
|---paragraphs []
|   |---context
|   |---qas []
|   |   |---answers []
|   |   |   |---answer_start
|   |   |   |---text
|   |   |---question
|   |   |---id
version

"""

filename = 'training_set.json'

# Read the complete file (not just its first line) so that pretty-printed,
# multi-line JSON is parsed as well; JSON text is UTF-8 encoded
with open(filename, 'r', encoding='utf-8') as f:
    raw_data = f.read()

parsed_data = json.loads(raw_data)['data']

context_list = []
context_index = -1
# NOTE: despite its name, this counter advances once per entry of `data`
# (i.e. per title), not once per paragraph
paragraph_index = -1

dataset = {'paragraph_index': [], 'context_index': [], 'question': [], 'answer_start': [], 'answer_end': [], 'answer_text': []}

for paragraph_index, article in enumerate(parsed_data):
    for paragraph in article['paragraphs']:
        # Contexts are stored once and referenced by position from the dataframe
        context_list.append(paragraph['context'])
        context_index = len(context_list) - 1

        for qa in paragraph['qas']:
            question = qa['question']

            # One dataframe row per (question, answer) pair
            for answer in qa['answers']:
                answer_start = answer['answer_start']
                answer_text = answer['text']

                # End offset computed as start offset plus answer length
                answer_end = answer_start + len(answer_text)

                dataset['paragraph_index'].append(paragraph_index)
                dataset['context_index'].append(context_index)
                dataset['question'].append(question)
                dataset['answer_start'].append(answer_start)
                dataset['answer_end'].append(answer_end)
                dataset['answer_text'].append(answer_text)

df = pd.DataFrame.from_dict(dataset)

df.head()

Unnamed: 0,paragraph_index,context_index,question,answer_start,answer_end,answer_text
0,0,0,To whom did the Virgin Mary allegedly appear i...,515,541,Saint Bernadette Soubirous
1,0,0,What is in front of the Notre Dame Main Building?,188,213,a copper statue of Christ
2,0,0,The Basilica of the Sacred heart at Notre Dame...,279,296,the Main Building
3,0,0,What is the Grotto at Notre Dame?,381,420,a Marian place of prayer and reflection
4,0,0,What sits on top of the Main Building at Notre...,92,126,a golden statue of the Virgin Mary


In [None]:
# The only thing that could be removed is the phonetic transcription, but it rarely occurs.
# Moreover, removing it would require updating the indexes.

# Show one context/question pair every 100 rows among the first 4000 rows
for sample_label in range(0, 4000, 100):
    # print('Title:   ', title_list[df['title_index'][i]])
    context_text = context_list[df['context_index'][sample_label]]
    print('Context: ', context_text)
    question_text = df['question'][sample_label]
    print('Question:', question_text, "\n")

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? 

Context:  One of the main driving forces in the growth of the University was its football team, the Notre Dame Fighting Irish. Knute Rockne became head coach in 1918. Under Rockne, the Irish would post a record

In [None]:
"""
--- OLD CODE ---
# Split eseguito secondo quanto scritto dai tutor: 
'''
   If you split the dataset in training and validation, we suggest you to do the splitting based on the title: 
   all the questions/paragraphs regarding the same title should be in the same split.
'''
train_and_val_df, test_df = train_test_split(df, test_size=0.15)
train_df, val_df = train_test_split(train_and_val_df, test_size=0.15)
val_df_unmodifiable = val_df.copy(deep=True)

print(train_df.shape)
print(val_df.shape)

# Questo codice probabilmente non va bene: essendoci solo 441 titoli distinti lo split randomico li mette tutti sia in val test che in train test
for _, row in val_df_unmodifiable.iterrows():
  if row["title_index"] in train_df["title_index"].tolist():
    tmp = pd.DataFrame([[row["title_index"], 
                         row["context_index"],
                         row["question"],
                         row["answer_start"],
                         row["answer_end"],
                         row["answer_text"]]], 
                       columns=["title_index", "context_index", "question",
                                "answer_start", "answer_end", "answer_text"])
    train_df = train_df.append(tmp)
    val_df = val_df[val_df['title_index'] != row["title_index"]]

print(train_df.shape)
print(val_df.shape)

np.sort(train_df['title_index'].unique())
np.sort(val_df_unmodifiable['title_index'].unique())
"""

# Define split ratios
test_ratio = 0.2
val_ratio = 0.2

# Build array of group indexes (values of the `paragraph_index` column) and shuffle them,
# so that all the rows sharing an index end up in the same split
paragraph_indexes = df['paragraph_index'].unique()
np.random.shuffle(paragraph_indexes)
n_samples = len(paragraph_indexes)

# Size of each split: the test set is carved out first, then validation is
# taken from what is left
test_size = int(test_ratio * n_samples)
train_val_size = n_samples - test_size
val_size = int(val_ratio * train_val_size)
train_size = train_val_size - val_size

# Slice with non-negative bounds: negative bounds such as `[-test_size:]`
# would select the whole array whenever a split size is 0
train_indexes = paragraph_indexes[:train_size]
val_indexes = paragraph_indexes[train_size:train_size + val_size]
test_indexes = paragraph_indexes[train_size + val_size:]

assert train_size == len(train_indexes), 'Something went wrong with train set slicing'
assert val_size == len(val_indexes), 'Something went wrong with val set slicing'
assert test_size == len(test_indexes), 'Something went wrong with test set slicing'

print('Number of train paragraphs:', train_size)
print('Number of validation paragraphs:', val_size)
print('Number of test paragraphs:', test_size)

# Split dataframe (Series.isin replaces the deprecated np.in1d)
df_train = df[df['paragraph_index'].isin(train_indexes)]
df_val = df[df['paragraph_index'].isin(val_indexes)]
df_test = df[df['paragraph_index'].isin(test_indexes)]

print('\nNumber of train samples:', len(df_train))
print('Number of validation samples:', len(df_val))
print('Number of test samples:', len(df_test))

Number of train paragraphs: 284
Number of validation paragraphs: 70
Number of test paragraphs: 88

Number of train samples: 57451
Number of validation samples: 12921
Number of test samples: 17227


# 2. Word embeddings

In [None]:
# Dimensionality of the pre-trained GloVe vectors to load
emb_dim = 50
glove_model_name = f'glove-wiki-gigaword-{emb_dim}'

print('Downloading GloVe model...')
glove_model = gloader.load(glove_model_name)
print('\nDownload completed.')

Downloading GloVe model...

Download completed.


## Play with a little dataset

In [None]:
# Debug switch: when enabled, every split is cut down to its first
# SMALL_DATASET_SIZE rows so that the rest of the notebook runs quickly
USE_SMALL_DATASET = True
SMALL_DATASET_SIZE = 100

if USE_SMALL_DATASET:
    df_train = df_train.iloc[:SMALL_DATASET_SIZE]
    df_val = df_val.iloc[:SMALL_DATASET_SIZE]
    df_test = df_test.iloc[:SMALL_DATASET_SIZE]

In [None]:
def build_vocabulary(corpus: List[List[str]],
                     old_word_listing: Optional[List[str]] = None) -> Tuple[Dict[int, str], Dict[str, int], List[str]]:
    """Build index <-> word mappings for a tokenized corpus.

    Args:
        corpus: List of token lists.
        old_word_listing: Existing vocabulary to extend. Its words keep their
            positions and new words are appended after them. When omitted,
            the vocabulary starts with the PAD token at index 0.

    Returns:
        A tuple ``(idx_to_word, word_to_idx, word_listing)`` where
        ``word_listing`` holds the distinct words in first-seen order.
    """
    flat_tokens = [token for token_list in corpus for token in token_list]

    # PAD is only the starting point of a brand new vocabulary; an extended
    # vocabulary inherits whatever the old listing contains
    base_listing = [PAD] if old_word_listing is None else list(old_word_listing)

    # De-duplicate while preserving order; deduplicating base and tokens
    # together keeps every word (PAD included) at a single index
    word_listing = list(OrderedDict.fromkeys(base_listing + flat_tokens))

    idx_to_word = dict(enumerate(word_listing))
    word_to_idx = {word: idx for idx, word in enumerate(word_listing)}

    return idx_to_word, word_to_idx, word_listing

def _tokenize_split(split_df):
    """Tokenize contexts and questions of one split and collect its answer spans.

    Returns the raw context strings, the tokenized contexts, the tokenized
    questions (all pandas Series) and the list of (answer_start, answer_end) tuples.
    """
    contexts = split_df['context_index'].apply(lambda x: context_list[x])
    context_tokens = contexts.apply(lambda x: nltk.word_tokenize(x))
    question_tokens = split_df['question'].apply(lambda x: nltk.word_tokenize(x))
    spans = list(zip(split_df['answer_start'].tolist(), split_df['answer_end'].tolist()))
    return contexts, context_tokens, question_tokens, spans


# Tokenize corpus
Contexts, X_trainC, X_trainQ, Y_train = _tokenize_split(df_train)
X_train = X_trainC.tolist() + X_trainQ.tolist()

Contexts, X_valC, X_valQ, Y_val = _tokenize_split(df_val)
X_val = X_valC.tolist() + X_valQ.tolist()

# The test contexts must come from df_test (they were previously read from df_val)
Contexts, X_testC, X_testQ, Y_test = _tokenize_split(df_test)
X_test = X_testC.tolist() + X_testQ.tolist()

# Get word mappings for each set: each vocabulary extends the previous one,
# so the reported sizes are cumulative
train_idx_to_word, train_word_to_idx, train_word_listing = build_vocabulary(X_train)
val_idx_to_word, val_word_to_idx, val_word_listing = build_vocabulary(X_val, old_word_listing=train_word_listing)
test_idx_to_word, test_word_to_idx, test_word_listing = build_vocabulary(X_test, old_word_listing=val_word_listing)

print('Words in training set:', len(train_word_listing))
print('Words in validation set:', len(val_word_listing))
print('Words in test set:', len(test_word_listing))

Words in training set: 1338
Words in validation set: 1840
Words in test set: 2015


In [None]:
# Out-of-vocabulary words: words of each listing without a GloVe vector.
# Membership is tested on the model itself (`word in glove_model`) rather than
# on its `.vocab` attribute, which no longer exists in gensim >= 4.0
train_oov_words = [word for word in train_word_listing if word not in glove_model and word != PAD]
val_oov_words = [word for word in val_word_listing if word not in glove_model and word != PAD]
test_oov_words = [word for word in test_word_listing if word not in glove_model and word != PAD]

print(f'Total OOV terms in training set: {len(train_oov_words)} ({float(len(train_oov_words)) / len(train_word_listing) * 100:.2f}%)')
print(f'Total OOV terms in validation set: {len(val_oov_words)} ({float(len(val_oov_words)) / len(val_word_listing) * 100:.2f}%)')
print(f'Total OOV terms in test set: {len(test_oov_words)} ({float(len(test_oov_words)) / len(test_word_listing) * 100:.2f}%)')

Total OOV terms in training set: 366 (27.35%)
Total OOV terms in validation set: 444 (24.13%)
Total OOV terms in test set: 488 (24.22%)


In [None]:
def build_embedding_matrix(embedding_model: gensim.models.keyedvectors.Word2VecKeyedVectors,
                           word_to_idx: Dict[str, int],
                           oov_words: List[str],
                           old_embedding_matrix: Optional[np.ndarray] = None) -> np.ndarray:
    """Build the embedding matrix for a vocabulary.

    Words known to ``embedding_model`` get their pre-trained vector, OOV words
    get a random vector drawn from a normal distribution whose mean and std
    are the averages of the per-vector means and stds of the known words, and
    PAD keeps the zero vector.

    Args:
        embedding_model: Pre-trained word vectors.
        word_to_idx: Mapping from word to row index of the matrix.
        oov_words: Words of the vocabulary that ``embedding_model`` lacks.
        old_embedding_matrix: Matrix built for a smaller vocabulary. Its rows
            are copied verbatim into the first rows of the result, so words
            already embedded (in particular random OOV vectors) keep the same
            embedding. This assumes ``word_to_idx`` extends the old
            vocabulary without changing the indexes of the old words.

    Returns:
        Array of shape ``(len(word_to_idx), embedding_model.vector_size)``.

    Raises:
        ValueError: If ``old_embedding_matrix`` has more rows than the
            vocabulary or a different vector size.
    """
    vector_size = embedding_model.vector_size
    # Initialize embedding matrix with all zeros
    embedding_matrix = np.zeros((len(word_to_idx), vector_size))
    # Set for O(1) membership tests inside the loops below
    oov_set = set(oov_words)

    # Reuse the rows of the previous matrix, if any
    n_reused = 0
    if old_embedding_matrix is not None:
        n_reused = old_embedding_matrix.shape[0]
        if n_reused > len(word_to_idx) or old_embedding_matrix.shape[1] != vector_size:
            raise ValueError('old_embedding_matrix is not compatible with the given vocabulary')
        embedding_matrix[:n_reused] = old_embedding_matrix

    # Helper function to analyze embeddings
    def analyze_embedding() -> Tuple[float, float]:
        mean_list, std_list = [], []
        for word in tqdm(word_to_idx.keys(), leave=False):
            if word not in oov_set and word != PAD:
                embed = embedding_model[word]
                # Compute mean and std
                mean_list.append(np.mean(embed))
                std_list.append(np.std(embed))

        return mean(mean_list), mean(std_list)

    # Get mean and std for the embeddings
    embedding_mean, embedding_std = analyze_embedding()

    for word, idx in tqdm(word_to_idx.items(), leave=False):
        # Rows inherited from the old matrix are already filled
        if idx < n_reused:
            continue
        # If word is PAD no action is performed (it will be assigned the zero vector)
        if word not in oov_set and word != PAD:
            embedding_matrix[idx] = embedding_model[word]
        elif word in oov_set:
            embedding_matrix[idx] = np.random.normal(loc=embedding_mean, scale=embedding_std, size=vector_size)

    return embedding_matrix

# Embedding matrix for the training vocabulary only (used for training)
train_embedding_matrix = build_embedding_matrix(embedding_model=glove_model,
                                                word_to_idx=train_word_to_idx,
                                                oov_words=train_oov_words)
print('Shape of embedding matrix (training set):', train_embedding_matrix.shape)

# Embedding matrix for the training + validation vocabulary (used for validation)
val_embedding_matrix = build_embedding_matrix(embedding_model=glove_model,
                                              word_to_idx=val_word_to_idx,
                                              oov_words=val_oov_words,
                                              old_embedding_matrix=train_embedding_matrix)
print('Shape of embedding matrix (validation set):', val_embedding_matrix.shape)

# Embedding matrix for the training + validation + test vocabulary (used for test)
test_embedding_matrix = build_embedding_matrix(embedding_model=glove_model,
                                               word_to_idx=test_word_to_idx,
                                               oov_words=test_oov_words,
                                               old_embedding_matrix=val_embedding_matrix)
print('Shape of embedding matrix (test set):', test_embedding_matrix.shape)

HBox(children=(FloatProgress(value=0.0, max=1338.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=1338.0), HTML(value='')))

Shape of embedding matrix (training set): (1338, 50)


HBox(children=(FloatProgress(value=0.0, max=1840.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=1840.0), HTML(value='')))

Shape of embedding matrix (validation set): (1840, 50)


HBox(children=(FloatProgress(value=0.0, max=2015.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=2015.0), HTML(value='')))

Shape of embedding matrix (test set): (2015, 50)


In [None]:
class Embedder:
    """Turn batches of token sequences into padded tensors of word embeddings."""

    def __init__(self,
                 embedding_matrix: np.ndarray,
                 word_to_idx: Dict[str, int]):
        # Row i of the matrix is the embedding of the word mapped to index i
        self.embedding_matrix = embedding_matrix
        self.word_to_idx = word_to_idx

    def get_word_embedding(self,
                           batch_word_seq: Tuple[List[str]],
                           max_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed a batch of token sequences.

        Empty sequences (the placeholders of partial batches) are skipped.
        Every other sequence is padded with PAD up to ``max_len``; sequences
        longer than ``max_len`` are truncated to it.

        Args:
            batch_word_seq: Token sequences of the batch.
            max_len: Common length of the embedded sequences. Defaults to the
                length of the longest sequence in the batch.

        Returns:
            A float32 tensor of shape ``(n_sequences, max_len, emb_dim)`` and
            an int32 tensor with the (possibly truncated) length of every
            kept sequence, both placed on DEVICE.

        Raises:
            KeyError: If a token is missing from ``word_to_idx``.
        """
        # Handle partial batches: drop the empty placeholder sequences
        sequences = [word_seq for word_seq in batch_word_seq if word_seq]

        if max_len is None:
            # Find maximum length of batch (0 when nothing is left to embed)
            max_len = max(map(len, sequences), default=0)

        index_matrix = np.empty((len(sequences), max_len), dtype=np.int64)
        lengths = []
        for row, word_seq in enumerate(sequences):
            tokens = word_seq[:max_len]
            # Keep track of the length actually fed to the model
            lengths.append(len(tokens))
            # Pad sequence and convert it to vocabulary indexes
            padded_seq = tokens + [PAD] * (max_len - len(tokens))
            index_matrix[row] = [self.word_to_idx[w] for w in padded_seq]

        # A single fancy-indexing lookup embeds the whole batch at once
        batch_embedding = np.asarray(self.embedding_matrix)[index_matrix]

        # Device-agnostic construction (works on both CPU and GPU)
        return (torch.as_tensor(batch_embedding, dtype=torch.float32, device=DEVICE),
                torch.as_tensor(lengths, dtype=torch.int32, device=DEVICE))

# One embedding helper per split, each pairing a matrix with its own word mapping
train_embedder = Embedder(embedding_matrix=train_embedding_matrix, word_to_idx=train_word_to_idx)
val_embedder = Embedder(embedding_matrix=val_embedding_matrix, word_to_idx=val_word_to_idx)
test_embedder = Embedder(embedding_matrix=test_embedding_matrix, word_to_idx=test_word_to_idx)

## Prepare data for training

In [None]:
# Each sample is a triple (context tokens, question tokens, answer span)
train_data = list(zip(X_trainC.tolist(), X_trainQ.tolist(), Y_train))
val_data = list(zip(X_valC.tolist(), X_valQ.tolist(), Y_val))
test_data = list(zip(X_testC.tolist(), X_testQ.tolist(), Y_test))

# 3. Define model

In [None]:
class SentenceEncoder(ABC, nn.Module):
    """Abstract interface for modules that encode a batch of sequences."""

    @abstractmethod
    def get_output_dim(self):
        """Return the output dimension of the encoder (to be defined by subclasses)."""

    @abstractmethod
    def forward(self, x: torch.Tensor, x_lengths: Optional[torch.IntTensor] = None) -> torch.Tensor:
        """Encode the batch ``x``, optionally given the sequence lengths ``x_lengths``."""

In [None]:
class RNNAverage(SentenceEncoder):
    """Encode a sequence as the mean of the LSTM outputs over its time steps."""

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 bidirectional: Optional[bool] = True,
                 num_layers: Optional[int] = 1):
        super().__init__()

        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            bidirectional=bidirectional,
                            batch_first=True,
                            num_layers=num_layers)

        self.bidirectional = bidirectional
        self.num_directions = 2 if self.bidirectional else 1
        self.num_layers = num_layers

    def get_output_dim(self):
        """Return the size of the encoding: hidden size times number of directions."""
        return self.lstm.hidden_size * self.num_directions

    def _get_lstm_features(self, x: torch.Tensor, x_lengths: torch.IntTensor) -> torch.Tensor:
        """Run the LSTM on the padded batch ``x``, ignoring the padded positions."""
        # Ignore padding in LSTM
        x = nn.utils.rnn.pack_padded_sequence(x, x_lengths.cpu(), batch_first=True, enforce_sorted=False)  # lengths.cpu() needed because of pytorch bug
        # Feed into LSTM
        x, _ = self.lstm(x)
        # Undo packaging (padded positions come back as zeros)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x

    def forward(self, x: torch.Tensor, x_lengths: Optional[torch.IntTensor] = None) -> torch.Tensor:
        """Return the per-sequence mean of the LSTM outputs.

        Args:
            x: Batch-first input, unpacked as (batch, sequence, features).
            x_lengths: Length of every sequence of the batch. When omitted,
                all sequences are treated as unpadded (full length).
        """
        if x_lengths is None:
            # No padding information: every sequence spans the whole time axis
            batch_size, seq_len, _ = x.size()
            x_lengths = torch.full((batch_size,), seq_len, dtype=torch.int64, device=x.device)

        # Feed into LSTM
        x = self._get_lstm_features(x, x_lengths)

        # Return the mean along second dim (taking into account padding):
        # padded outputs are zero, so the sum is divided by the true lengths
        return torch.sum(x, dim=1) / x_lengths.unsqueeze(dim=1)

In [None]:
class ClaimProver(nn.Module):
    """Encode two inputs with a shared encoder, merge them and classify the result."""

    def __init__(self,
                 encoder: 'SentenceEncoder',
                 classifier: nn.Module,
                 merge_strategy: Optional[str] = 'concat',
                 use_cos_sim: Optional[bool] = False):
        super().__init__()

        self.encoder = encoder
        self.classifier = classifier
        self.merge_strategy = merge_strategy
        self.use_cos_sim = use_cos_sim

    def forward(self,
                x1: torch.Tensor,
                x2: torch.Tensor,
                x1_lengths: Optional[torch.IntTensor] = None,
                x2_lengths: Optional[torch.IntTensor] = None):
        """Compute the classifier output for the pair of batches ``x1`` and ``x2``.

        Args:
            x1, x2: Input batches, each fed to the shared encoder.
            x1_lengths, x2_lengths: Optional sequence lengths passed to the encoder.

        Returns:
            Output of the classifier applied to the merged encodings.

        Raises:
            ValueError: If ``merge_strategy`` is not 'sum', 'avg' or 'concat'.
        """
        # Encode both inputs
        x1 = self.encoder(x1, x1_lengths)
        x2 = self.encoder(x2, x2_lengths)

        # Merge according to strategy
        if self.merge_strategy == 'sum':
            encoding = x1 + x2  # SUM
        elif self.merge_strategy == 'avg':
            encoding = (x1 + x2) / 2  # AVERAGE
        elif self.merge_strategy == 'concat':
            encoding = torch.cat((x1, x2), dim=1)  # CONCATENATION
        else:
            raise ValueError(f"Unknown merge strategy {self.merge_strategy!r}: expected 'sum', 'avg' or 'concat'")

        # Optionally add cosine similarity feature
        if self.use_cos_sim:
            # Row-wise cosine similarity between the two encodings, computed on
            # detached tensors so that no gradient flows through this feature
            cos_sim = F.cosine_similarity(x1.detach(), x2.detach(), dim=1).unsqueeze(dim=1)
            # (batch, emb_size) + (batch, 1) -> (batch, emb_size + 1)
            encoding = torch.cat((encoding, cos_sim), dim=1)

        return self.classifier(encoding)

# 4. Define training loop

In [None]:
def train(model: ClaimProver,
          data: List[Tuple[List[str], List[str], Tuple[int, int]]],
          batch_size: int,
          criterion: Callable[[torch.Tensor, torch.Tensor], float],
          optimizer: torch.optim.Optimizer,
          embedder: Embedder,
          max_len: Optional[int] = None,
          verbose: Optional[bool] = False,
          scaler: Union[torch.cuda.amp.grad_scaler.GradScaler, bool] = False) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``data``.

    Args:
        model: Model to optimize; called with the two embedded inputs and their lengths.
        data: Samples as triples; the first two items are token lists and the
            first element of the third item is used as the target value.
        batch_size: Number of samples per batch.
        criterion: Loss function applied to (scores, targets of shape (batch, 1)).
        optimizer: Optimizer updating the model parameters.
        embedder: Helper turning token lists into padded embedding tensors.
        max_len: Optional common sequence length forwarded to the embedder.
        verbose: Whether to show a progress bar.
        scaler: Gradient scaler enabling automatic mixed precision; any falsy
            value trains in full precision.

    Returns:
        Mean batch loss and accuracy (in percent) over the pass.
    """
    model.train()
    use_amp = bool(scaler)

    loss_data = []
    correct = 0
    total = 0

    # Create batch iterator
    batch_iter = batch_iteration(data, batch_size)
    if verbose:
        steps = (len(data) + batch_size - 1) // batch_size  # ceil division
        batch_iter = tqdm(batch_iter, total=steps, leave=False)

    for batch_idx, batch in enumerate(batch_iter):
        # Extract samples
        (claims, evidences, labels) = to_tuple_of_lists(batch)

        # Get encoding of claims, evidences and labels on DEVICE
        claims_tensor, claims_lengths = embedder.get_word_embedding(claims, max_len)
        evidences_tensor, evidences_lengths = embedder.get_word_embedding(evidences, max_len)
        # Only the first element of each label is predicted (start of the span);
        # the empty labels of the partial-batch placeholders are skipped
        labels_tensor = torch.tensor([l[0] for l in labels if len(l) > 0], dtype=torch.float32, device=DEVICE)

        # Make prediction (under autocast only when mixed precision is enabled)
        # https://pytorch.org/docs/stable/notes/amp_examples.html
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            scores = model(claims_tensor, evidences_tensor, claims_lengths, evidences_lengths)
            loss = criterion(scores, labels_tensor.unsqueeze(dim=1))

        # Backpropagation
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Compute accuracy: a prediction is correct when its integer part equals the label
        total_batch = scores.shape[0]
        predictions = torch.flatten(scores.detach())

        if batch_idx == 0:
            print("Example of prediction vs label:",int(predictions[0].item()), int(labels_tensor[0].item()))
        # 64-bit integers: an int16 cast would wrap around for values above 32767
        correct_batch = (predictions.long() == labels_tensor.long()).sum().item()

        # Update history
        loss_data.append(loss.item())
        correct += correct_batch
        total += total_batch

    return mean(loss_data), correct / total * 100

def evaluate(model: nn.Module,
             data: List[Tuple[List[str], List[str], str]],
             batch_size: int,
             criterion: Callable[[torch.Tensor, torch.Tensor], float],
             embedder: Embedder,
             label_to_idx: Dict[str, int],
             max_len: Optional[int] = None,
             verbose: Optional[bool] = False) -> Tuple[float, float]:
    """Evaluate ``model`` on ``data`` without updating its parameters.

    Scores are passed through a sigmoid and thresholded at 0.5; the resulting
    0/1 predictions are compared with the label indexes.

    NOTE(review): this function still maps labels through ``label_to_idx`` and
    thresholds a sigmoid, whereas ``train`` uses the first element of each
    label as a numeric target - confirm which convention is intended.

    Args:
        model: Model to evaluate.
        data: Samples as triples (token list, token list, label).
        batch_size: Number of samples per batch.
        criterion: Loss function applied to (scores, targets of shape (batch, 1)).
        embedder: Helper turning token lists into padded embedding tensors.
        label_to_idx: Mapping from label to its numeric index.
        max_len: Optional common sequence length forwarded to the embedder.
        verbose: Whether to show a progress bar.

    Returns:
        Mean batch loss and accuracy (in percent) over ``data``.
    """
    loss_data = []
    correct = 0
    total = 0

    # Switch to evaluation mode, remembering the previous mode to restore it afterwards
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Create batch iterator
            batch_iter = batch_iteration(data, batch_size)
            if verbose:
                steps = (len(data) + batch_size - 1) // batch_size  # ceil division
                batch_iter = tqdm(batch_iter, total=steps, leave=False)

            for batch in batch_iter:
                # Extract samples
                (claims, evidences, labels) = to_tuple_of_lists(batch)

                # Get encoding of claims, evidences and labels on DEVICE
                claims_tensor, claims_lengths = embedder.get_word_embedding(claims, max_len)
                evidences_tensor, evidences_lengths = embedder.get_word_embedding(evidences, max_len)
                # Empty labels belong to the placeholders of partial batches and are skipped
                labels_tensor = torch.tensor([label_to_idx[l] for l in labels if len(l) > 0], dtype=torch.float32, device=DEVICE)

                # Make prediction
                scores = model(claims_tensor, evidences_tensor, claims_lengths, evidences_lengths)

                # Compute loss
                loss = criterion(scores, labels_tensor.unsqueeze(dim=1))

                # Compute accuracy; flatten (not squeeze) keeps a 1-D tensor even for batches of one
                total_batch = scores.shape[0]
                predictions = torch.sigmoid(torch.flatten(scores)) > 0.5
                # Element-wise comparison of the 0/1 predictions with the label indexes
                correct_batch = (predictions.long() == labels_tensor.long()).sum().item()

                # Update history
                loss_data.append(loss.item())
                correct += correct_batch
                total += total_batch
    finally:
        model.train(was_training)

    return mean(loss_data), correct / total * 100
    

def training_loop(model: nn.Module,
                  train_data: List[Tuple[List[str], List[str]]],
                  optimizer: torch.optim.Optimizer,
                  epochs: int,
                  batch_size: int,
                  criterion: Callable[[torch.Tensor, torch.Tensor], float],
                  train_embedder: Embedder,
                  val_embedder: Optional[Embedder] = None,
                  max_len: Optional[int] = None,
                  lr_scheduler: torch.optim.lr_scheduler = None,
                  val_data: Optional[List[Tuple[List[str], List[str]]]] = None,
                  early_stopping: Optional[bool] = False,
                  patience: Optional[int] = 5,
                  tolerance: Optional[float] = 1e-4,
                  checkpoint_path: Optional[str] = None,
                  verbose: Optional[bool] = True,
                  seed: Optional[int] = 42,
                  mix_scale: Optional[bool] = True) -> (Dict[str, List[float]]):
    """Run the epoch loop: train, optionally validate, optionally stop early.

    Each epoch shuffles the training samples, calls `train`, and records the
    returned loss and accuracy. When both `val_data` and `val_embedder` are
    given, the validation samples are shuffled too and `evaluate` is called.

    NOTE: `evaluate` is called with `label_to_idx`, which is not a parameter
    of this function; it is looked up in the enclosing (module) scope and must
    exist there whenever validation is enabled.

    Args:
        model: network forwarded to `train` / `evaluate`.
        train_data: training samples. The caller's list is NOT reordered;
            shuffling happens on a shallow copy.
        optimizer: optimizer forwarded to `train`.
        epochs: maximum number of epochs.
        batch_size: batch size forwarded to `train` / `evaluate`.
        criterion: loss function forwarded to `train` / `evaluate`.
        train_embedder: embedder forwarded to `train`.
        val_embedder: embedder forwarded to `evaluate`; validation runs only
            if this and `val_data` are both truthy.
        max_len: forwarded unchanged to `train` / `evaluate`.
        lr_scheduler: if truthy, its `step()` is called at the end of every
            epoch that does not trigger early stopping.
        val_data: validation samples (shuffled on a shallow copy).
        early_stopping: enable early stopping. It takes effect only when
            validation runs AND `checkpoint_path` is set.
        patience: number of consecutive non-improving epochs that triggers
            early stopping.
        tolerance: validation loss must drop below the best value minus this
            amount to count as an improvement.
        checkpoint_path: file where the best `state_dict` is saved and from
            which it is restored when early stopping triggers.
        verbose: print per-epoch metrics; also forwarded to `train` /
            `evaluate`.
        seed: seed for Python's `random` module; `None` leaves the RNG
            untouched (0 is a valid seed).
        mix_scale: if truthy a `GradScaler` is created and passed to `train`,
            otherwise `False` is passed instead.

    Returns:
        Dict with the per-epoch lists 'loss', 'accuracy', 'val_loss' and
        'val_accuracy' (the last two stay empty when validation is disabled).
    """
    # Set seed for reproducibility ('is not None' so that seed=0 is honoured)
    if seed is not None:
        random.seed(seed)

    history = {'loss': [],
               'accuracy': [],
               'val_loss': [],
               'val_accuracy': []}

    # Validation requires both the data and an embedder for it
    do_validation = bool(val_data and val_embedder)

    # Shuffle private shallow copies so the caller's lists keep their order.
    # The sequence of random.shuffle calls is unchanged, so the per-epoch
    # ordering is the same as shuffling the originals in place.
    train_data = list(train_data)
    if do_validation:
        val_data = list(val_data)

    # State for early stopping (only consulted when it is enabled)
    min_val_loss = np.inf
    no_improve_counter = 0

    scaler = GradScaler() if mix_scale else False   # https://pytorch.org/docs/stable/notes/amp_examples.html

    for ep in range(epochs):
        if verbose:
            print('-' * 100)
            print(f'Epoch {ep + 1}/{epochs}')

        # Shuffle training set at each epoch
        random.shuffle(train_data)

        start = time()
        train_loss, train_accuracy = train(model, train_data, batch_size, criterion, optimizer, train_embedder, max_len, verbose, scaler)
        end = time()

        history['loss'].append(train_loss)
        history['accuracy'].append(train_accuracy)
        if verbose:
            print(f'\tLoss: {train_loss:.5f} - Accuracy: {train_accuracy:.2f}% [Time elapsed: {end - start:.2f} s]')

        # Do validation if required
        if do_validation:

            # Shuffle validation set at each epoch
            random.shuffle(val_data)

            start = time()
            # label_to_idx is resolved from module scope (see docstring)
            val_loss, val_accuracy = evaluate(model, val_data, batch_size, criterion, val_embedder, label_to_idx, max_len, verbose)
            end = time()

            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)
            if verbose:
                print(f'\tValidation loss: {val_loss:.5f} - Validation accuracy: {val_accuracy:.2f}% [Time elapsed: {end - start:.2f} s]')

            # Early stopping needs a checkpoint file to save/restore the best weights
            if early_stopping and checkpoint_path:
                # If validation loss is lower than minimum, update minimum
                if val_loss < min_val_loss - tolerance:
                    min_val_loss = val_loss
                    no_improve_counter = 0

                    # Save model
                    torch.save(model.state_dict(), checkpoint_path)
                # otherwise increment counter
                else:
                    no_improve_counter += 1
                # If loss did not improve for 'patience' epochs, break
                if no_improve_counter == patience:
                    if verbose:
                        print(f'Early stopping: no improvement in validation loss for {patience} epochs from {min_val_loss:.5f}')
                    # Restore model to best
                    model.load_state_dict(torch.load(checkpoint_path))
                    model.eval()
                    break

        # If lr scheduling is used, invoke next step
        if lr_scheduler:
            lr_scheduler.step()

    return history

# 5. Run training

In [None]:
# Training schedule: number of epochs (EP) and mini-batch size (BS)
EP, BS = 100, 1

# Merge strategy; 'concat' doubles the classifier input width below
merge_strategy = 'concat'

# The model is fitted with a regression objective (mean squared error)
criterion = nn.MSELoss()

# Recurrent encoder: hidden size 32, one layer, bidirectional
# (original note: 32 * 2 = 64 due to bidirectionality)
encoder = RNNAverage(input_size=emb_dim, hidden_size=32, num_layers=1, bidirectional=True).to(DEVICE)

# Width of the classifier input depends on the merge strategy
if merge_strategy == 'concat':
    classifier_in_features = encoder.get_output_dim() * 2
else:
    classifier_in_features = encoder.get_output_dim()

# Single-output linear head, then the full model moved to DEVICE
classifier = nn.Linear(classifier_in_features, 1)
prover = ClaimProver(encoder, classifier).to(DEVICE)

In [None]:
# Adam over every parameter of the assembled model, learning rate 0.1
optimizer = torch.optim.Adam(prover.parameters(), lr=1e-1)

# Launch training. val_data is None, so training_loop skips validation and
# the early-stopping / checkpoint arguments have no effect in this run.
history = training_loop(
    model=prover,
    optimizer=optimizer,
    criterion=criterion,
    epochs=EP,
    batch_size=BS,
    train_data=train_data,
    train_embedder=train_embedder,
    val_data=None,
    val_embedder=val_embedder,
    early_stopping=True,
    checkpoint_path='rnn_last.pt',
)


----------------------------------------------------------------------------------------------------
Epoch 1/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 0.063232421875 488.0
	Loss: 157648.38610 - Accuracy: 0.00% [Time elapsed: 2.01 s]
----------------------------------------------------------------------------------------------------
Epoch 2/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 216.5 80.0
	Loss: 119211.48375 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 3/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 369.75 68.0
	Loss: 115449.80188 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 4/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 348.25 1237.0
	Loss: 113353.03188 - Accuracy: 0.00% [Time elapsed: 1.94 s]
----------------------------------------------------------------------------------------------------
Epoch 5/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 331.75 126.0
	Loss: 113457.83313 - Accuracy: 0.00% [Time elapsed: 1.85 s]
----------------------------------------------------------------------------------------------------
Epoch 6/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 355.5 32.0
	Loss: 112546.46250 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 7/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 345.75 1237.0
	Loss: 110900.38375 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 8/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 321.0 71.0
	Loss: 110029.10875 - Accuracy: 0.00% [Time elapsed: 1.78 s]
----------------------------------------------------------------------------------------------------
Epoch 9/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 337.75 188.0
	Loss: 107645.19500 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 10/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 371.25 908.0
	Loss: 108988.61250 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 11/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 332.25 624.0
	Loss: 106687.28234 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 12/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 317.5 496.0
	Loss: 107792.19750 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 13/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 302.75 32.0
	Loss: 106464.67062 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 14/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 321.5 279.0
	Loss: 106040.94875 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 15/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 412.75 222.0
	Loss: 105156.89250 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 16/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 393.25 862.0
	Loss: 104081.08812 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 17/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 306.0 213.0
	Loss: 106872.87000 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 18/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 384.0 675.0
	Loss: 106765.42313 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 19/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 300.75 126.0
	Loss: 102998.12938 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 20/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 324.75 248.0
	Loss: 101148.64750 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 21/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 308.0 8.0
	Loss: 103291.30875 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 22/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 384.75 385.0
	Loss: 102226.84938 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 23/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 292.0 0.0
	Loss: 110607.37813 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 24/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 344.5 1049.0
	Loss: 108818.36125 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 25/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 372.5 515.0
	Loss: 106634.01688 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 26/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 331.0 46.0
	Loss: 106878.65062 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 27/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 322.0 66.0
	Loss: 109414.63375 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 28/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 325.25 362.0
	Loss: 109425.78500 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 29/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 382.25 377.0
	Loss: 107247.12250 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 30/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 360.5 123.0
	Loss: 104777.19375 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 31/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 355.0 251.0
	Loss: 102354.19500 - Accuracy: 0.00% [Time elapsed: 1.84 s]
----------------------------------------------------------------------------------------------------
Epoch 32/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 338.5 1099.0
	Loss: 103106.98250 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 33/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 294.0 136.0
	Loss: 101934.01938 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 34/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 336.5 6.0
	Loss: 102435.12312 - Accuracy: 0.00% [Time elapsed: 1.87 s]
----------------------------------------------------------------------------------------------------
Epoch 35/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 367.75 0.0
	Loss: 103366.98375 - Accuracy: 1.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 36/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 369.5 368.0
	Loss: 100504.76125 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 37/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 433.25 862.0
	Loss: 98604.34750 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 38/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 336.5 73.0
	Loss: 97660.60375 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 39/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 295.5 68.0
	Loss: 96237.43109 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 40/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 437.0 598.0
	Loss: 96766.98750 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 41/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 337.5 145.0
	Loss: 94766.34187 - Accuracy: 0.00% [Time elapsed: 1.84 s]
----------------------------------------------------------------------------------------------------
Epoch 42/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 521.5 675.0
	Loss: 96377.14469 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 43/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 317.75 145.0
	Loss: 96438.21016 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 44/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 280.25 358.0
	Loss: 95125.41750 - Accuracy: 0.00% [Time elapsed: 1.84 s]
----------------------------------------------------------------------------------------------------
Epoch 45/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 317.25 49.0
	Loss: 93850.37625 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 46/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 200.625 908.0
	Loss: 105389.23422 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 47/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 531.0 647.0
	Loss: 95212.22406 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 48/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 306.5 4.0
	Loss: 94662.26891 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 49/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 230.5 86.0
	Loss: 95964.35516 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 50/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 293.75 92.0
	Loss: 94207.11141 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 51/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 361.0 753.0
	Loss: 95062.84109 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 52/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 413.5 598.0
	Loss: 93409.73078 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 53/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 465.0 675.0
	Loss: 88797.93703 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 54/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 371.0 92.0
	Loss: 86660.60406 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 55/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 441.0 595.0
	Loss: 84027.33516 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 56/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 334.5 68.0
	Loss: 87124.05766 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 57/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 330.25 136.0
	Loss: 80896.65813 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 58/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 215.0 0.0
	Loss: 76822.34594 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 59/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 176.625 0.0
	Loss: 80232.46578 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 60/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 340.25 908.0
	Loss: 89728.99145 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 61/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 147.0 8.0
	Loss: 88560.35781 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 62/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 165.25 8.0
	Loss: 86306.45813 - Accuracy: 0.00% [Time elapsed: 1.84 s]
----------------------------------------------------------------------------------------------------
Epoch 63/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 266.0 66.0
	Loss: 81037.17062 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 64/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 270.25 196.0
	Loss: 80199.71328 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 65/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 151.25 228.0
	Loss: 100710.49334 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 66/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 459.0 565.0
	Loss: 92253.16146 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 67/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 453.0 356.0
	Loss: 94093.81160 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 68/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 564.0 891.0
	Loss: 91950.08329 - Accuracy: 1.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 69/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 424.5 155.0
	Loss: 88149.85414 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 70/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 296.25 155.0
	Loss: 84612.66098 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 71/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 222.25 0.0
	Loss: 84878.36840 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 72/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 345.0 331.0
	Loss: 82422.16763 - Accuracy: 1.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 73/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 138.625 0.0
	Loss: 82298.44961 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 74/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 322.5 155.0
	Loss: 80022.14821 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 75/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 372.5 638.0
	Loss: 76851.16441 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 76/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 368.5 123.0
	Loss: 75130.91750 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 77/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 398.75 430.0
	Loss: 76448.20473 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 78/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 196.125 362.0
	Loss: 74433.34703 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 79/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 603.0 1588.0
	Loss: 74126.57195 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 80/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 414.75 618.0
	Loss: 75434.39613 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 81/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 275.5 4.0
	Loss: 73547.37031 - Accuracy: 0.00% [Time elapsed: 1.84 s]
----------------------------------------------------------------------------------------------------
Epoch 82/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 196.25 80.0
	Loss: 72310.09462 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 83/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 346.25 303.0
	Loss: 71519.62328 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 84/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 390.25 109.0
	Loss: 75630.27109 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 85/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 326.5 138.0
	Loss: 73003.79219 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 86/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 730.0 1099.0
	Loss: 79023.98563 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 87/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 295.5 145.0
	Loss: 76846.66391 - Accuracy: 0.00% [Time elapsed: 1.82 s]
----------------------------------------------------------------------------------------------------
Epoch 88/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 221.25 1237.0
	Loss: 78436.20844 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 89/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 282.25 204.0
	Loss: 73158.74629 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 90/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 408.75 155.0
	Loss: 69918.18504 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 91/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 332.5 213.0
	Loss: 72107.08313 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 92/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 170.0 85.0
	Loss: 75649.24669 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 93/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 277.25 126.0
	Loss: 75343.60254 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 94/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 402.0 358.0
	Loss: 74435.21863 - Accuracy: 0.00% [Time elapsed: 1.81 s]
----------------------------------------------------------------------------------------------------
Epoch 95/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 336.25 138.0
	Loss: 76939.39844 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 96/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 505.25 757.0
	Loss: 70987.93395 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 97/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 803.5 1588.0
	Loss: 70131.80020 - Accuracy: 0.00% [Time elapsed: 1.83 s]
----------------------------------------------------------------------------------------------------
Epoch 98/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 104.75 119.0
	Loss: 68041.24418 - Accuracy: 0.00% [Time elapsed: 1.79 s]
----------------------------------------------------------------------------------------------------
Epoch 99/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 180.0 80.0
	Loss: 65621.28402 - Accuracy: 0.00% [Time elapsed: 1.80 s]
----------------------------------------------------------------------------------------------------
Epoch 100/100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

example of prediction vs label: 510.75 963.0
	Loss: 67374.85684 - Accuracy: 0.00% [Time elapsed: 1.80 s]
