In [2]:
import os
import sys
import math
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm.notebook import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_metric

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using {device}")

using cuda


In [4]:
def get_sent_str(sentence_list):
    sent = " ".join(sentence_list)
    sent = re.sub(r" (?P<punc>[.?,])", r"\1", sent)
    return sent

def get_sent_list(sentences):
    sent_list = []
    for sent in sentences:
        sent_list.append(get_sent_str(sent))
    return sent_list

In [5]:
PRETRAINED_MODEL = 't5-base'
DIR = "question_generator/"
BATCH_SIZE = 1
SEQ_LENGTH = 512
EPOCHS = 100
USE_ANSWER = True
BEST = "toeflqa_finetune_withanswer.pt"
BEST_HF = "toeflqa_finetune_hf_withanswer"

# Check whether the specified path exists or not
if not os.path.exists(DIR):
    os.makedirs(path)

tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<answer>', '<context>']}
)
bertscore = load_metric("bertscore")

In [6]:
import importlib
util = importlib.import_module("data.TOEFL-QA.utils")


def set_fuzzy_context(key, raw_data):
    question = [raw_data[key]["question"]]
    results = []
    for ref in get_sent_list(raw_data[key]["sentences"]):
        results.append(bertscore.compute(predictions=question, references=[ref], lang='en'))
    idx = np.argsort(-1 * np.array([i["precision"] for i in results]).ravel())
    top5 = idx[:5]
    sent_list = get_sent_list(raw_data[key]["sentences"])
#         def get_surrounding(sent_list, ind):
#             if ind == 0:
#                 return set_list[i] + sent_list[i+1]
#             elif ind == len(sent_list)-1:
#                 return set_list[i-1] + sent_list[i]
#             else:
#                 return set_list[i-1] + sent_list[i] + sent_list[i+1]
    raw_data[key]["context"] = " ".join([sent_list[i] for i in sorted(top5)]) # reorder sentences

def preprocess(raw_data):
    keys = list(raw_data.keys())
    for key in tqdm(keys):
        raw_data[key]["question"] = get_sent_str(raw_data[key]["question"])
        raw_data[key]["answer"] = get_sent_str(raw_data[key]["answer"])
        set_fuzzy_context(key, raw_data)


if "train_processed.npy" in os.listdir():
    train_raw = np.load("train_processed.npy", allow_pickle=True).item()
    dev_raw = np.load("dev_processed.npy", allow_pickle=True).item()
    test_raw = np.load("test_processed.npy", allow_pickle=True).item()
else:
    TOEFL_PATH = "./data/TOEFL-QA/data/"
    raw = util.load_data(TOEFL_PATH)
    train_raw, dev_raw, test_raw = tuple(raw)
    preprocess(dev_raw)
    np.save("dev_processed.npy", dev_raw)
    preprocess(test_raw)
    np.save("test_processed.npy", test_raw)
    preprocess(train_raw)
    np.save("train_processed.npy", train_raw)

### Problem
Context for both race and toefl texts / question contexts are not guaranteed to be <512 tokens, which is expected by our T5 base. Also, they both include lots of info not related to the question, which is the normal expectation in previous QG (ex. SQuAD models)
### Possible solutions
- Annotate.
- Use a metric (bertscore) to find semantically similar sentences to a given answer, and take the top n for a "fuzzy" context

In [7]:
class TOEFLDataset(Dataset):
    def __init__(self, data_dict):
        self.data = data_dict
        self.idx_map = list(data_dict.keys())       

    def __len__(self):
         return len(self.idx_map)

    def __getitem__(self, idx):   
        row = self.data[self.idx_map[idx]]
        if USE_ANSWER:
            s = '<answer> ' + row['answer'] + ' <context> '+ row['context']
        else:
            s = row['context']
        encoded_text = tokenizer(
            s, 
            padding=True,
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )
        encoded_text['input_ids'] = torch.squeeze(encoded_text['input_ids'])
        encoded_text['attention_mask'] = torch.squeeze(encoded_text['attention_mask'])

        encoded_question = tokenizer(
            row['question'],
            padding=True,
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        encoded_question['input_ids'] = torch.squeeze(encoded_question['input_ids'])

        return encoded_text.to(device), encoded_question.to(device)

    
train_set = TOEFLDataset(train_raw)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
dev_set = TOEFLDataset(dev_raw)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE)


In [8]:

def train(epoch, best_val_loss):
    model.train()
    total_loss = 0.
    for batch_index, batch in tqdm(enumerate(train_loader)):
        data, target = batch 
        # data, target = (data.to(device), target.to(device))
        optimizer.zero_grad()
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            labels=masked_labels
        )
        output[0].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += output[0].item()
        if batch_index % 500 == 499:
            print(f'| epoch {epoch} | {batch_index}/{len(train_loader)} batches | loss {total_loss / 500}')
            total_loss = 0
def valid(epoch):
    model.eval()
    total_loss = 0.
    for batch_index, batch in tqdm(enumerate(dev_loader)):
        data, target = batch 
        # data, target = (data.to(device), target.to(device))
        optimizer.zero_grad()
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            labels=masked_labels
        )
        total_loss += output[0].item()
    return total_loss / len(dev_loader)
        
def mask_label_padding(labels):
    MASK_ID = -100
    labels[labels==tokenizer.pad_token_id] = MASK_ID
    return labels


In [9]:
config = T5Config(decoder_start_token_id=tokenizer.pad_token_id) # eos
model = T5ForConditionalGeneration(config).from_pretrained('t5-base')
model.resize_token_embeddings(len(tokenizer)) # to account for new special tokens
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [10]:
best_val_loss = float("inf")

for epoch in range(1, EPOCHS + 1):

    train(epoch, best_val_loss)
    torch.cuda.empty_cache()
    val_loss = valid(model)
    torch.cuda.empty_cache()
    print(f'\nend of epoch {epoch}\n valid loss: {val_loss}\n')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_val_loss,
            'using_answer': USE_ANSWER
        }, DIR + BEST + ".best")
        model.save_pretrained(DIR + BEST_HF)
        print("Model saved.\n")
    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_val_loss,
            'using_answer': USE_ANSWER
        }, DIR + BEST + f".epoch{epoch}")
        model.save_pretrained(DIR + BEST_HF + f".epoch{epoch}")


0it [00:00, ?it/s]

| epoch 1 | 499/717 batches | loss 4.538641332864762


0it [00:00, ?it/s]


end of epoch 1
 valid loss: 4.035245191666387

Model saved.



0it [00:00, ?it/s]

| epoch 2 | 499/717 batches | loss 4.105282034873962


0it [00:00, ?it/s]


end of epoch 2
 valid loss: 3.850341437324401

Model saved.



0it [00:00, ?it/s]

| epoch 3 | 499/717 batches | loss 3.8998029475212097


0it [00:00, ?it/s]


end of epoch 3
 valid loss: 3.727632139959643

Model saved.



0it [00:00, ?it/s]

| epoch 4 | 499/717 batches | loss 3.80651589846611


0it [00:00, ?it/s]


end of epoch 4
 valid loss: 3.627522256105177

Model saved.



0it [00:00, ?it/s]

| epoch 5 | 499/717 batches | loss 3.667395164728165


0it [00:00, ?it/s]


end of epoch 5
 valid loss: 3.5362822673013135

Model saved.



0it [00:00, ?it/s]

| epoch 6 | 499/717 batches | loss 3.6902861633300783


0it [00:00, ?it/s]


end of epoch 6
 valid loss: 3.461399413885609

Model saved.



0it [00:00, ?it/s]

| epoch 7 | 499/717 batches | loss 3.546624441385269


0it [00:00, ?it/s]


end of epoch 7
 valid loss: 3.3940740131562754

Model saved.



0it [00:00, ?it/s]

| epoch 8 | 499/717 batches | loss 3.5129781420230866


0it [00:00, ?it/s]


end of epoch 8
 valid loss: 3.330298476642178

Model saved.



0it [00:00, ?it/s]

| epoch 9 | 499/717 batches | loss 3.5321483438014982


0it [00:00, ?it/s]


end of epoch 9
 valid loss: 3.2724914925713695

Model saved.



0it [00:00, ?it/s]

| epoch 10 | 499/717 batches | loss 3.400990841150284


0it [00:00, ?it/s]


end of epoch 10
 valid loss: 3.2148521590617394

Model saved.



0it [00:00, ?it/s]

| epoch 11 | 499/717 batches | loss 3.377206029653549


0it [00:00, ?it/s]


end of epoch 11
 valid loss: 3.161537802988483

Model saved.



0it [00:00, ?it/s]

| epoch 12 | 499/717 batches | loss 3.3200590505599976


0it [00:00, ?it/s]


end of epoch 12
 valid loss: 3.1114692543783495

Model saved.



0it [00:00, ?it/s]

| epoch 13 | 499/717 batches | loss 3.2852959697246553


0it [00:00, ?it/s]


end of epoch 13
 valid loss: 3.0591880313811766

Model saved.



0it [00:00, ?it/s]

| epoch 14 | 499/717 batches | loss 3.2335167646408083


0it [00:00, ?it/s]


end of epoch 14
 valid loss: 3.011935034105855

Model saved.



0it [00:00, ?it/s]

| epoch 15 | 499/717 batches | loss 3.160422114133835


0it [00:00, ?it/s]


end of epoch 15
 valid loss: 2.9640226085339823

Model saved.



0it [00:00, ?it/s]

| epoch 16 | 499/717 batches | loss 3.1425124335289003


0it [00:00, ?it/s]


end of epoch 16
 valid loss: 2.9151013137832766

Model saved.



0it [00:00, ?it/s]

| epoch 17 | 499/717 batches | loss 3.0986773991584777


0it [00:00, ?it/s]


end of epoch 17
 valid loss: 2.8676502570029228

Model saved.



0it [00:00, ?it/s]

| epoch 18 | 499/717 batches | loss 3.0550667934417723


0it [00:00, ?it/s]


end of epoch 18
 valid loss: 2.8243789201782596

Model saved.



0it [00:00, ?it/s]

| epoch 19 | 499/717 batches | loss 3.0182498788833616


0it [00:00, ?it/s]


end of epoch 19
 valid loss: 2.7774195363444667

Model saved.



0it [00:00, ?it/s]

| epoch 20 | 499/717 batches | loss 2.988192443728447


0it [00:00, ?it/s]


end of epoch 20
 valid loss: 2.7350482700332517

Model saved.



0it [00:00, ?it/s]

| epoch 21 | 499/717 batches | loss 2.941230203151703


0it [00:00, ?it/s]


end of epoch 21
 valid loss: 2.692930104271058

Model saved.



0it [00:00, ?it/s]

| epoch 22 | 499/717 batches | loss 2.8941103732585907


0it [00:00, ?it/s]


end of epoch 22
 valid loss: 2.6557378019056013

Model saved.



0it [00:00, ?it/s]

| epoch 23 | 499/717 batches | loss 2.8412578921318055


0it [00:00, ?it/s]


end of epoch 23
 valid loss: 2.6185262385875947

Model saved.



0it [00:00, ?it/s]

| epoch 24 | 499/717 batches | loss 2.840201823711395


0it [00:00, ?it/s]


end of epoch 24
 valid loss: 2.582727247668851

Model saved.



0it [00:00, ?it/s]

| epoch 25 | 499/717 batches | loss 2.764110788464546


0it [00:00, ?it/s]


end of epoch 25
 valid loss: 2.5477264211062463

Model saved.



0it [00:00, ?it/s]

| epoch 26 | 499/717 batches | loss 2.722447738647461


0it [00:00, ?it/s]


end of epoch 26
 valid loss: 2.514594932236979

Model saved.



0it [00:00, ?it/s]

| epoch 27 | 499/717 batches | loss 2.68292047727108


0it [00:00, ?it/s]


end of epoch 27
 valid loss: 2.4851963207606347

Model saved.



0it [00:00, ?it/s]

| epoch 28 | 499/717 batches | loss 2.686730672836304


0it [00:00, ?it/s]


end of epoch 28
 valid loss: 2.45620801227708

Model saved.



0it [00:00, ?it/s]

| epoch 29 | 499/717 batches | loss 2.6419184033870695


0it [00:00, ?it/s]


end of epoch 29
 valid loss: 2.4292063338141285

Model saved.



0it [00:00, ?it/s]

| epoch 30 | 499/717 batches | loss 2.634173022031784


0it [00:00, ?it/s]


end of epoch 30
 valid loss: 2.403324986177106

Model saved.



0it [00:00, ?it/s]

| epoch 31 | 499/717 batches | loss 2.601015928149223


0it [00:00, ?it/s]


end of epoch 31
 valid loss: 2.382127540246133

Model saved.



0it [00:00, ?it/s]

| epoch 32 | 499/717 batches | loss 2.577909213423729


0it [00:00, ?it/s]


end of epoch 32
 valid loss: 2.3601091643494945

Model saved.



0it [00:00, ?it/s]

| epoch 33 | 499/717 batches | loss 2.5500490970611573


0it [00:00, ?it/s]


end of epoch 33
 valid loss: 2.339824096810433

Model saved.



0it [00:00, ?it/s]

| epoch 34 | 499/717 batches | loss 2.5101222573518753


0it [00:00, ?it/s]


end of epoch 34
 valid loss: 2.3188392885269655

Model saved.



0it [00:00, ?it/s]

| epoch 35 | 499/717 batches | loss 2.519061965227127


0it [00:00, ?it/s]


end of epoch 35
 valid loss: 2.3014374221524885

Model saved.



0it [00:00, ?it/s]

| epoch 36 | 499/717 batches | loss 2.496462462425232


0it [00:00, ?it/s]


end of epoch 36
 valid loss: 2.28392090335969

Model saved.



0it [00:00, ?it/s]

| epoch 37 | 499/717 batches | loss 2.4806521886587145


0it [00:00, ?it/s]


end of epoch 37
 valid loss: 2.2680902322453838

Model saved.



0it [00:00, ?it/s]

| epoch 38 | 499/717 batches | loss 2.455274636030197


0it [00:00, ?it/s]


end of epoch 38
 valid loss: 2.248744133018678

Model saved.



0it [00:00, ?it/s]

| epoch 39 | 499/717 batches | loss 2.45423726272583


0it [00:00, ?it/s]


end of epoch 39
 valid loss: 2.232491780192621

Model saved.



0it [00:00, ?it/s]

| epoch 40 | 499/717 batches | loss 2.4093517248630523


0it [00:00, ?it/s]


end of epoch 40
 valid loss: 2.2167368617749985

Model saved.



0it [00:00, ?it/s]

| epoch 41 | 499/717 batches | loss 2.370382460474968


0it [00:00, ?it/s]


end of epoch 41
 valid loss: 2.2012519399004598

Model saved.



0it [00:00, ?it/s]

| epoch 42 | 499/717 batches | loss 2.3780036413669587


0it [00:00, ?it/s]


end of epoch 42
 valid loss: 2.187485032504605

Model saved.



0it [00:00, ?it/s]

| epoch 43 | 499/717 batches | loss 2.344994927048683


0it [00:00, ?it/s]


end of epoch 43
 valid loss: 2.1727646108596557

Model saved.



0it [00:00, ?it/s]

| epoch 44 | 499/717 batches | loss 2.3535254603624343


0it [00:00, ?it/s]


end of epoch 44
 valid loss: 2.1579449546913945

Model saved.



0it [00:00, ?it/s]

| epoch 45 | 499/717 batches | loss 2.3348680485486986


0it [00:00, ?it/s]


end of epoch 45
 valid loss: 2.1454836870393446

Model saved.



0it [00:00, ?it/s]

| epoch 46 | 499/717 batches | loss 2.3480776466131212


0it [00:00, ?it/s]


end of epoch 46
 valid loss: 2.1316986675223997

Model saved.



0it [00:00, ?it/s]

| epoch 47 | 499/717 batches | loss 2.289084834456444


0it [00:00, ?it/s]


end of epoch 47
 valid loss: 2.11963687308373

Model saved.



0it [00:00, ?it/s]

| epoch 48 | 499/717 batches | loss 2.2851183156967165


0it [00:00, ?it/s]


end of epoch 48
 valid loss: 2.1066328714932165

Model saved.



0it [00:00, ?it/s]

| epoch 49 | 499/717 batches | loss 2.2588912382125854


0it [00:00, ?it/s]


end of epoch 49
 valid loss: 2.096117600798607

Model saved.



0it [00:00, ?it/s]

| epoch 50 | 499/717 batches | loss 2.2617908848524095


0it [00:00, ?it/s]


end of epoch 50
 valid loss: 2.082743039534938

Model saved.



0it [00:00, ?it/s]

| epoch 51 | 499/717 batches | loss 2.2332958716154097


0it [00:00, ?it/s]


end of epoch 51
 valid loss: 2.06874674462503

Model saved.



0it [00:00, ?it/s]

| epoch 52 | 499/717 batches | loss 2.250848911821842


0it [00:00, ?it/s]


end of epoch 52
 valid loss: 2.059017979810315

Model saved.



0it [00:00, ?it/s]

| epoch 53 | 499/717 batches | loss 2.2144177820682525


0it [00:00, ?it/s]


end of epoch 53
 valid loss: 2.0473670695097215

Model saved.



0it [00:00, ?it/s]

| epoch 54 | 499/717 batches | loss 2.198577549338341


0it [00:00, ?it/s]


end of epoch 54
 valid loss: 2.0363061836650296

Model saved.



0it [00:00, ?it/s]

| epoch 55 | 499/717 batches | loss 2.1970655655860902


0it [00:00, ?it/s]


end of epoch 55
 valid loss: 2.028272806156066

Model saved.



0it [00:00, ?it/s]

| epoch 56 | 499/717 batches | loss 2.2222498569488525


0it [00:00, ?it/s]


end of epoch 56
 valid loss: 2.0167443733061514

Model saved.



0it [00:00, ?it/s]

| epoch 57 | 499/717 batches | loss 2.1559309970140457


0it [00:00, ?it/s]


end of epoch 57
 valid loss: 2.0078239051565046

Model saved.



0it [00:00, ?it/s]

| epoch 58 | 499/717 batches | loss 2.1546210602521896


0it [00:00, ?it/s]


end of epoch 58
 valid loss: 1.99763885140419

Model saved.



0it [00:00, ?it/s]

| epoch 59 | 499/717 batches | loss 2.144968629479408


0it [00:00, ?it/s]


end of epoch 59
 valid loss: 1.9894552677869797

Model saved.



0it [00:00, ?it/s]

| epoch 60 | 499/717 batches | loss 2.1602500712871553


0it [00:00, ?it/s]


end of epoch 60
 valid loss: 1.9799850688826652

Model saved.



0it [00:00, ?it/s]

| epoch 61 | 499/717 batches | loss 2.1406464042663575


0it [00:00, ?it/s]


end of epoch 61
 valid loss: 1.9709373481812016

Model saved.



0it [00:00, ?it/s]

| epoch 62 | 499/717 batches | loss 2.1200028413534167


0it [00:00, ?it/s]


end of epoch 62
 valid loss: 1.962822570916145

Model saved.



0it [00:00, ?it/s]

| epoch 63 | 499/717 batches | loss 2.1036755992174148


0it [00:00, ?it/s]


end of epoch 63
 valid loss: 1.9540144267582125

Model saved.



0it [00:00, ?it/s]

| epoch 64 | 499/717 batches | loss 2.100446059823036


0it [00:00, ?it/s]


end of epoch 64
 valid loss: 1.9467865983324666

Model saved.



0it [00:00, ?it/s]

| epoch 65 | 499/717 batches | loss 2.1070330909490584


0it [00:00, ?it/s]


end of epoch 65
 valid loss: 1.9379822376274294

Model saved.



0it [00:00, ?it/s]

| epoch 66 | 499/717 batches | loss 2.0460037440657617


0it [00:00, ?it/s]


end of epoch 66
 valid loss: 1.9297715242831939

Model saved.



0it [00:00, ?it/s]

| epoch 67 | 499/717 batches | loss 2.0889414209127426


0it [00:00, ?it/s]


end of epoch 67
 valid loss: 1.9231510700718049

Model saved.



0it [00:00, ?it/s]

| epoch 68 | 499/717 batches | loss 2.092577376008034


0it [00:00, ?it/s]


end of epoch 68
 valid loss: 1.9153100977982245

Model saved.



0it [00:00, ?it/s]

| epoch 69 | 499/717 batches | loss 2.050533132433891


0it [00:00, ?it/s]


end of epoch 69
 valid loss: 1.9077933012477812

Model saved.



0it [00:00, ?it/s]

| epoch 70 | 499/717 batches | loss 2.003868785381317


0it [00:00, ?it/s]


end of epoch 70
 valid loss: 1.900347335684684

Model saved.



0it [00:00, ?it/s]

| epoch 71 | 499/717 batches | loss 2.0781834441423417


0it [00:00, ?it/s]


end of epoch 71
 valid loss: 1.892648660367535

Model saved.



0it [00:00, ?it/s]

| epoch 72 | 499/717 batches | loss 2.035882868647575


0it [00:00, ?it/s]


end of epoch 72
 valid loss: 1.8854388363899723

Model saved.



0it [00:00, ?it/s]

| epoch 73 | 499/717 batches | loss 1.9983119668960572


0it [00:00, ?it/s]


end of epoch 73
 valid loss: 1.8790765981520376

Model saved.



0it [00:00, ?it/s]

| epoch 74 | 499/717 batches | loss 2.049252517223358


0it [00:00, ?it/s]


end of epoch 74
 valid loss: 1.8735267255575425

Model saved.



0it [00:00, ?it/s]

| epoch 75 | 499/717 batches | loss 1.996967957019806


0it [00:00, ?it/s]


end of epoch 75
 valid loss: 1.8646133128673799

Model saved.



0it [00:00, ?it/s]

| epoch 76 | 499/717 batches | loss 2.0443249456882477


0it [00:00, ?it/s]


end of epoch 76
 valid loss: 1.8580830851870198

Model saved.



0it [00:00, ?it/s]

| epoch 77 | 499/717 batches | loss 2.0220965540409086


0it [00:00, ?it/s]


end of epoch 77
 valid loss: 1.851395441639808

Model saved.



0it [00:00, ?it/s]

| epoch 78 | 499/717 batches | loss 1.9886880830526352


0it [00:00, ?it/s]


end of epoch 78
 valid loss: 1.8461049876866802

Model saved.



0it [00:00, ?it/s]

| epoch 79 | 499/717 batches | loss 1.970353590607643


0it [00:00, ?it/s]


end of epoch 79
 valid loss: 1.8418939385202624

Model saved.



0it [00:00, ?it/s]

| epoch 80 | 499/717 batches | loss 1.9829037055969239


0it [00:00, ?it/s]


end of epoch 80
 valid loss: 1.834889345592068

Model saved.



0it [00:00, ?it/s]

| epoch 81 | 499/717 batches | loss 1.9831677789688111


0it [00:00, ?it/s]


end of epoch 81
 valid loss: 1.829062483243404

Model saved.



0it [00:00, ?it/s]

| epoch 82 | 499/717 batches | loss 1.9630228187441825


0it [00:00, ?it/s]


end of epoch 82
 valid loss: 1.825135997706844

Model saved.



0it [00:00, ?it/s]

| epoch 83 | 499/717 batches | loss 1.9318252145051955


0it [00:00, ?it/s]


end of epoch 83
 valid loss: 1.8210171562048696

Model saved.



0it [00:00, ?it/s]

| epoch 84 | 499/717 batches | loss 1.9705634903907776


0it [00:00, ?it/s]


end of epoch 84
 valid loss: 1.8143139068157441

Model saved.



0it [00:00, ?it/s]

| epoch 85 | 499/717 batches | loss 1.950994043827057


0it [00:00, ?it/s]


end of epoch 85
 valid loss: 1.8073829792199596

Model saved.



0it [00:00, ?it/s]

| epoch 86 | 499/717 batches | loss 1.949674876332283


0it [00:00, ?it/s]


end of epoch 86
 valid loss: 1.8010462813319699

Model saved.



0it [00:00, ?it/s]

| epoch 87 | 499/717 batches | loss 1.921083747625351


0it [00:00, ?it/s]


end of epoch 87
 valid loss: 1.7973634071888462

Model saved.



0it [00:00, ?it/s]

| epoch 88 | 499/717 batches | loss 1.8991777464151383


0it [00:00, ?it/s]


end of epoch 88
 valid loss: 1.7935914483762556

Model saved.



0it [00:00, ?it/s]

| epoch 89 | 499/717 batches | loss 1.9500391551852225


0it [00:00, ?it/s]


end of epoch 89
 valid loss: 1.788841368930955

Model saved.



0it [00:00, ?it/s]

| epoch 90 | 499/717 batches | loss 1.9799502484798432


0it [00:00, ?it/s]


end of epoch 90
 valid loss: 1.7842199929779576

Model saved.



0it [00:00, ?it/s]

| epoch 91 | 499/717 batches | loss 1.895256817817688


0it [00:00, ?it/s]


end of epoch 91
 valid loss: 1.780836617994693

Model saved.



0it [00:00, ?it/s]

| epoch 92 | 499/717 batches | loss 1.89742516374588


0it [00:00, ?it/s]


end of epoch 92
 valid loss: 1.7773417008499945

Model saved.



0it [00:00, ?it/s]

| epoch 93 | 499/717 batches | loss 1.8893733902573586


0it [00:00, ?it/s]


end of epoch 93
 valid loss: 1.773577657678435

Model saved.



0it [00:00, ?it/s]

| epoch 94 | 499/717 batches | loss 1.9076864994764329


0it [00:00, ?it/s]


end of epoch 94
 valid loss: 1.7691536411162345

Model saved.



0it [00:00, ?it/s]

| epoch 95 | 499/717 batches | loss 1.9067354635000229


0it [00:00, ?it/s]


end of epoch 95
 valid loss: 1.764853359951127

Model saved.



0it [00:00, ?it/s]

| epoch 96 | 499/717 batches | loss 1.9124519326090812


0it [00:00, ?it/s]


end of epoch 96
 valid loss: 1.7610191558637927

Model saved.



0it [00:00, ?it/s]

| epoch 97 | 499/717 batches | loss 1.8909940891265868


0it [00:00, ?it/s]


end of epoch 97
 valid loss: 1.7561860408994459

Model saved.



0it [00:00, ?it/s]

| epoch 98 | 499/717 batches | loss 1.8722415781617165


0it [00:00, ?it/s]


end of epoch 98
 valid loss: 1.7528384315871424

Model saved.



0it [00:00, ?it/s]

| epoch 99 | 499/717 batches | loss 1.9022821728587151


0it [00:00, ?it/s]


end of epoch 99
 valid loss: 1.747938297688961

Model saved.



0it [00:00, ?it/s]

| epoch 100 | 499/717 batches | loss 1.8299595975875855


0it [00:00, ?it/s]


end of epoch 100
 valid loss: 1.7447156209138133

Model saved.

