In [5]:
import os
import sys
import math
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm.notebook import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_metric

In [6]:
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"using {device}")

using cuda


In [7]:
def get_sent_str(sentence_list):
    """Join a list of tokens into a single sentence string.

    Tokens are joined with single spaces, then any space directly in front of
    '.', '?' or ',' is removed so that punctuation attaches to the previous
    word. Other punctuation (e.g. '!' or ';') keeps its leading space.
    """
    joined = " ".join(sentence_list)
    return re.sub(r" (?P<punc>[.?,])", r"\1", joined)

def get_sent_list(sentences):
    """Turn each tokenised sentence in `sentences` into a string via get_sent_str."""
    return [get_sent_str(tokens) for tokens in sentences]

In [8]:
PRETRAINED_MODEL = 't5-base'
DIR = "question_generator/toeflqa_finetune_50epoch/"  # output directory for all checkpoints
BATCH_SIZE = 1
SEQ_LENGTH = 512        # truncation length for both the model input and the target question
EPOCHS = 50
USE_ANSWER = False      # if True, the input is '<answer> {answer} <context> {context}'
BEST = "toeflqa_finetune.pt"     # file-name stem for torch.save checkpoints
BEST_HF = "toeflqa_finetune_hf"  # directory-name stem for save_pretrained checkpoints
SEED = 42

# Seed torch (this also drives the DataLoader shuffle and dropout) and numpy so a
# Restart & Run All reproduces the same run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Create the output directory; exist_ok avoids the check-then-create race.
os.makedirs(DIR, exist_ok=True)

tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
# Marker tokens used to delimit answer and context when USE_ANSWER is True.
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<answer>', '<context>']}
)
bertscore = load_metric("bertscore")

In [9]:
import importlib
util = importlib.import_module("data.TOEFL-QA.utils")


def set_fuzzy_context(key, raw_data, top_k=5):
    """Build a "fuzzy" context for one example from its most question-like sentences.

    Every sentence of ``raw_data[key]["sentences"]`` is scored by BERTScore
    precision against ``raw_data[key]["question"]``. The ``top_k``
    highest-scoring sentences are joined in their original passage order and
    stored in ``raw_data[key]["context"]``.

    Args:
        key: key of the example inside ``raw_data``.
        raw_data: dict of examples; the selected entry is mutated in place.
        top_k: number of sentences to keep (default 5).

    NOTE(review): sentences are ranked against the *question*, which is also the
    generation target (the training labels). The selected context therefore leaks
    information about the label, and this ranking is unavailable when generating
    new questions. Consider ranking against the answer instead — TODO confirm intent.
    """
    entry = raw_data[key]
    sent_list = get_sent_list(entry["sentences"])
    if not sent_list:
        # Nothing to score: the context is empty.
        entry["context"] = ""
        return
    # One batched compute() call for all sentences instead of one call per sentence.
    scores = bertscore.compute(
        predictions=[entry["question"]] * len(sent_list),
        references=sent_list,
        lang='en',
    )
    precision = np.asarray(scores["precision"]).ravel()
    top = np.argsort(-precision)[:top_k]
    # Re-sort the chosen indices so the context keeps the passage's sentence order.
    entry["context"] = " ".join(sent_list[i] for i in sorted(top))

def preprocess(raw_data):
    """Normalise every example of a split in place and attach its fuzzy context.

    For each entry, the token lists under "question" and "answer" are replaced
    by joined strings (via get_sent_str), then set_fuzzy_context fills in
    "context". Not idempotent: a second pass would re-join the already-joined
    strings character by character.
    """
    for key in tqdm(list(raw_data)):
        entry = raw_data[key]
        for field in ("question", "answer"):
            entry[field] = get_sent_str(entry[field])
        set_fuzzy_context(key, raw_data)


# Preprocessing (BERTScore over every sentence) is slow, so each split is cached
# as a pickled dict in the working directory.
# NOTE: np.load(allow_pickle=True) unpickles the file; only load caches you created.
SPLIT_NAMES = ("train", "dev", "test")  # order in which util.load_data returns the splits
CACHE_FILES = {name: f"{name}_processed.npy" for name in SPLIT_NAMES}
TOEFL_PATH = "./data/TOEFL-QA/data/"


def _load_cached_split(name):
    """Load one preprocessed split dict written by np.save below."""
    return np.load(CACHE_FILES[name], allow_pickle=True).item()


if all(os.path.exists(path) for path in CACHE_FILES.values()):
    splits = {name: _load_cached_split(name) for name in SPLIT_NAMES}
else:
    train_split, dev_split, test_split = tuple(util.load_data(TOEFL_PATH))
    splits = {"train": train_split, "dev": dev_split, "test": test_split}
    # Process dev, test, then train; each split is saved as soon as it is done and
    # reused on the next run, so an interrupted run does not redo finished splits.
    for name in ("dev", "test", "train"):
        if os.path.exists(CACHE_FILES[name]):
            splits[name] = _load_cached_split(name)
        else:
            preprocess(splits[name])
            np.save(CACHE_FILES[name], splits[name])

train_raw, dev_raw, test_raw = (splits[name] for name in SPLIT_NAMES)

### Problem
Contexts in both the RACE and TOEFL datasets are not guaranteed to fit within the 512-token input length expected by our T5-base model. They also contain a lot of information unrelated to the question, unlike the short, focused passages that previous QG models (e.g., those trained on SQuAD) normally expect.
### Possible solutions
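- Annotate (e.g., manually mark the sentences relevant to each question).
- Use a metric (BERTScore) to find sentences that are semantically similar to a given answer, and take the top n as a "fuzzy" context. (Note: the code above currently scores sentences against the question, not the answer.)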

In [10]:
class TOEFLDataset(Dataset):
    """Map-style dataset yielding (encoded input, encoded target question) pairs.

    Each item comes from one entry of ``data_dict``. The input is the entry's
    ``context``, prefixed by ``<answer> {answer} <context>`` when USE_ANSWER is
    set. The target is the entry's ``question``. Both are tokenized with the
    module-level ``tokenizer``, truncated to SEQ_LENGTH, and moved to ``device``.

    NOTE(review): tensors are moved to ``device`` inside __getitem__, which only
    works with the DataLoader default ``num_workers=0``. Items are not padded to
    a common length, so BATCH_SIZE > 1 needs a padding ``collate_fn``.
    """

    def __init__(self, data_dict):
        self.data = data_dict
        # Positional index -> dict key, so the DataLoader can address items by int.
        self.idx_map = list(data_dict.keys())

    def __len__(self):
        return len(self.idx_map)

    def _encode(self, text):
        """Tokenize one string and drop the leading batch dimension of its tensors."""
        encoded = tokenizer(
            text,
            padding=True,
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) rather than squeeze(): a sequence that encodes to a single
        # token must stay 1-D instead of collapsing to a 0-d scalar.
        for name in ("input_ids", "attention_mask"):
            encoded[name] = encoded[name].squeeze(0)
        return encoded

    def __getitem__(self, idx):
        row = self.data[self.idx_map[idx]]
        if USE_ANSWER:
            source = '<answer> ' + row['answer'] + ' <context> ' + row['context']
        else:
            source = row['context']
        encoded_text = self._encode(source)
        encoded_question = self._encode(row['question'])
        return encoded_text.to(device), encoded_question.to(device)

    
# Wrap each split; only the training loader shuffles, the dev loader keeps dict order.
train_set, dev_set = TOEFLDataset(train_raw), TOEFLDataset(dev_raw)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE)


In [11]:

def train(epoch, best_val_loss, log_interval=500, max_grad_norm=0.5):
    """Run one training epoch over ``train_loader`` and return its mean loss.

    Args:
        epoch: epoch number, used only in the progress log.
        best_val_loss: unused; kept so existing calls keep working.
        log_interval: print the running mean loss every this many batches.
        max_grad_norm: threshold for gradient-norm clipping.

    Returns:
        Mean training loss over all batches of the epoch.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``mask_label_padding``. Batches are used as-is, i.e. they are expected to
    already be on ``device`` (TOEFLDataset moves them there).
    """
    model.train()
    interval_loss = 0.
    epoch_loss = 0.
    num_batches = len(train_loader)
    for batch_index, (data, target) in enumerate(tqdm(train_loader, total=num_batches)):
        optimizer.zero_grad()
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            labels=masked_labels,
        )
        loss = output.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        loss_value = loss.item()
        interval_loss += loss_value
        epoch_loss += loss_value
        batches_done = batch_index + 1
        if batches_done % log_interval == 0:
            print(f'| epoch {epoch} | {batches_done}/{num_batches} batches | loss {interval_loss / log_interval}')
            interval_loss = 0.
    return epoch_loss / max(num_batches, 1)
def valid(epoch):
    """Return the mean loss of ``model`` over ``dev_loader``.

    Args:
        epoch: unused; kept for call compatibility.

    Evaluation runs under ``torch.no_grad()`` so no autograd graph is kept
    alive for each batch. The optimizer is not touched, because no gradients
    are produced here.
    Raises ZeroDivisionError if ``dev_loader`` is empty.
    """
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for data, target in tqdm(dev_loader, total=len(dev_loader)):
            masked_labels = mask_label_padding(target['input_ids'])
            output = model(
                input_ids=data['input_ids'],
                attention_mask=data['attention_mask'],
                labels=masked_labels,
            )
            total_loss += output.loss.item()
    return total_loss / len(dev_loader)
        
def mask_label_padding(labels, pad_token_id=None, mask_id=-100):
    """Replace padding ids in ``labels`` with ``mask_id`` so the loss ignores them.

    The tensor is modified in place and also returned.

    Args:
        labels: integer tensor of target token ids.
        pad_token_id: id to mask; defaults to the module-level
            ``tokenizer.pad_token_id``.
        mask_id: replacement id; -100 is the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    labels[labels == pad_token_id] = mask_id
    return labels


In [12]:
# from_pretrained is a classmethod: calling it on an instance built from a separate
# T5Config would build a throwaway random model and ignore that config. Load from the
# class instead, and pass decoder_start_token_id as a config override.
model = T5ForConditionalGeneration.from_pretrained(
    PRETRAINED_MODEL,
    decoder_start_token_id=tokenizer.pad_token_id,  # decoding starts from the pad token id
)
config = model.config  # the configuration actually in use
model.resize_token_embeddings(len(tokenizer))  # to account for new special tokens
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [13]:
SNAPSHOT_EVERY = 10  # also keep a periodic snapshot every this many epochs


def _save_checkpoint(epoch, best_val_loss, torch_suffix, hf_suffix=""):
    """Save a torch checkpoint to DIR + BEST + torch_suffix and the HF model to DIR + BEST_HF + hf_suffix."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_val_loss,
        'using_answer': USE_ANSWER
    }, DIR + BEST + torch_suffix)
    model.save_pretrained(DIR + BEST_HF + hf_suffix)


best_val_loss = float("inf")

for epoch in range(1, EPOCHS + 1):
    train(epoch, best_val_loss)
    torch.cuda.empty_cache()
    val_loss = valid(epoch)
    torch.cuda.empty_cache()
    print(f'\nend of epoch {epoch}\n valid loss: {val_loss}\n')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(epoch, best_val_loss, ".best")
        print("Model saved.\n")
    if epoch % SNAPSHOT_EVERY == 0:
        # Periodic snapshot, independent of validation loss.
        epoch_suffix = f".epoch{epoch}"
        _save_checkpoint(epoch, best_val_loss, epoch_suffix, epoch_suffix)


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 1 | 499/717 batches | loss 4.689216994524002



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 1
 valid loss: 4.153259461925876

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 2 | 499/717 batches | loss 4.160969120979309



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 2
 valid loss: 3.96985526911674

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 3 | 499/717 batches | loss 3.976770967960358



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 3
 valid loss: 3.843560117867685

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 4 | 499/717 batches | loss 3.904599846363068



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 4
 valid loss: 3.7447966404499544

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 5 | 499/717 batches | loss 3.7769359216690064



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 5
 valid loss: 3.6629366711262734

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 6 | 499/717 batches | loss 3.6910454866886138



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 6
 valid loss: 3.5923522133981027

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 7 | 499/717 batches | loss 3.656113076210022



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 7
 valid loss: 3.5271683921737056

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 8 | 499/717 batches | loss 3.624266453027725



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 8
 valid loss: 3.467756860679196

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 9 | 499/717 batches | loss 3.5308442780971525



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 9
 valid loss: 3.4052160997544565

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 10 | 499/717 batches | loss 3.5089446218013762



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 10
 valid loss: 3.343626929867652

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 11 | 499/717 batches | loss 3.4588372704982757



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 11
 valid loss: 3.2862458132928416

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 12 | 499/717 batches | loss 3.3992341096401213



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 12
 valid loss: 3.2235606283910814

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 13 | 499/717 batches | loss 3.3091200115680697



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 13
 valid loss: 3.1702171093033207

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 14 | 499/717 batches | loss 3.3096687240600584



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 14
 valid loss: 3.1178105435063763

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 15 | 499/717 batches | loss 3.254407828092575



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 15
 valid loss: 3.0698092464477784

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 16 | 499/717 batches | loss 3.192989155769348



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 16
 valid loss: 3.0204951186333933

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 17 | 499/717 batches | loss 3.16435785984993



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 17
 valid loss: 2.975689322717728

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 18 | 499/717 batches | loss 3.1248135888576507



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 18
 valid loss: 2.935489003696749

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 19 | 499/717 batches | loss 3.1091285285949706



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 19
 valid loss: 2.893414206081821

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 20 | 499/717 batches | loss 3.01478049659729



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 20
 valid loss: 2.854168768851988

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 21 | 499/717 batches | loss 2.983701623916626



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 21
 valid loss: 2.8171825235889805

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 22 | 499/717 batches | loss 2.928907448530197



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 22
 valid loss: 2.7774456489470696

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 23 | 499/717 batches | loss 2.9096244103908537



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 23
 valid loss: 2.7460719962273874

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 24 | 499/717 batches | loss 2.8730768473148345



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 24
 valid loss: 2.7130833902666645

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 25 | 499/717 batches | loss 2.890997621178627



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 25
 valid loss: 2.6895058356946513

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 26 | 499/717 batches | loss 2.8466314625740052



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 26
 valid loss: 2.6572699950587366

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 27 | 499/717 batches | loss 2.8373003430366515



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 27
 valid loss: 2.63598580321958

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 28 | 499/717 batches | loss 2.800928282499313



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 28
 valid loss: 2.6062524011058192

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 29 | 499/717 batches | loss 2.7829314436912536



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 29
 valid loss: 2.5852173432227104

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 30 | 499/717 batches | loss 2.740534744977951



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 30
 valid loss: 2.5632630995204373

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 31 | 499/717 batches | loss 2.6969112257957457



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 31
 valid loss: 2.541425550656934

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 32 | 499/717 batches | loss 2.6860649000406265



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 32
 valid loss: 2.5219216356354375

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 33 | 499/717 batches | loss 2.6556204439401627



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 33
 valid loss: 2.4966114158591917

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 34 | 499/717 batches | loss 2.6481170970201493



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 34
 valid loss: 2.486057878021271

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 35 | 499/717 batches | loss 2.647343913435936



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 35
 valid loss: 2.4649335530496415

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 36 | 499/717 batches | loss 2.615205845594406



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 36
 valid loss: 2.4505757410680093

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 37 | 499/717 batches | loss 2.581712883591652



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 37
 valid loss: 2.434085246055357

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 38 | 499/717 batches | loss 2.5995835819244384



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 38
 valid loss: 2.416294053196907

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 39 | 499/717 batches | loss 2.6000881142616272



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 39
 valid loss: 2.4058116736911956

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 40 | 499/717 batches | loss 2.5603927417993546



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 40
 valid loss: 2.387182001625338

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 41 | 499/717 batches | loss 2.502403570652008



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 41
 valid loss: 2.3662302013366454

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 42 | 499/717 batches | loss 2.5466279537677763



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 42
 valid loss: 2.3622337459556517

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 43 | 499/717 batches | loss 2.4841026445627215



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 43
 valid loss: 2.346235811229675

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 44 | 499/717 batches | loss 2.515434943318367



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 44
 valid loss: 2.328579575304062

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 45 | 499/717 batches | loss 2.4991085435152054



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 45
 valid loss: 2.310061188474778

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 46 | 499/717 batches | loss 2.475265266776085



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 46
 valid loss: 2.2939956827509786

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 47 | 499/717 batches | loss 2.442198328256607



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 47
 valid loss: 2.2882378053280616

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 48 | 499/717 batches | loss 2.4456940335035324



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 48
 valid loss: 2.275715033854208

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 49 | 499/717 batches | loss 2.424694144487381



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 49
 valid loss: 2.2605835886732226

Model saved.



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

| epoch 50 | 499/717 batches | loss 2.4235960644483567



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



end of epoch 50
 valid loss: 2.2439743179467415

Model saved.

