In [1]:
import os
import sys
import math
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm.notebook import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from datasets import load_metric

In [2]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"using {device}")

using cuda


In [3]:
# A single space that directly precedes '.', '?' or ','.
_SPACE_BEFORE_PUNC = re.compile(r" (?P<punc>[.?,])")


def get_sent_str(sentence_list):
    """Detokenize a list of word tokens into one sentence string.

    Tokens are joined with single spaces, then the space in front of each
    '.', '?' and ',' is removed (other punctuation such as '!' keeps its space).
    """
    return _SPACE_BEFORE_PUNC.sub(r"\1", " ".join(sentence_list))


def get_sent_list(sentences):
    """Detokenize every tokenized sentence in ``sentences`` with get_sent_str."""
    return [get_sent_str(tokens) for tokens in sentences]

In [13]:
# --- experiment configuration ---------------------------------------------
PRETRAINED_MODEL = 't5-base'   # Hugging Face checkpoint name for the tokenizer
DIR = "question_generator/"    # presumably the checkpoint directory (matches the torch.save paths) -- confirm
BATCH_SIZE = 1                 # TOEFLDataset pads per example (padding=True), so >1 needs a padding collate_fn
SEQ_LENGTH = 512               # truncation limit for encoded contexts and questions
EPOCHS = 200

# Tokenizer, extended with two extra special tokens.
tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.add_special_tokens({
    'additional_special_tokens': ['<answer>', '<context>'],
})

# Metrics. NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases (metrics moved to the separate `evaluate` library) -- confirm the
# installed version still provides it.
sacrebleu = load_metric("sacrebleu")
bertscore = load_metric("bertscore")

In [14]:
import importlib

# The dataset helpers live under data/TOEFL-QA/. That directory name contains a
# hyphen, which is not a valid Python identifier, so the module is loaded from
# a dotted-path string instead of with an ordinary import statement.
util = importlib.import_module("data.TOEFL-QA.utils")


def set_fuzzy_context(key, raw_data, top_k=5):
    """Build a question-focused "fuzzy" context for one example, in place.

    Every sentence of ``raw_data[key]["sentences"]`` is scored against the
    question ``raw_data[key]["question"]`` with BERTScore. The ``top_k``
    sentences with the highest precision are kept, re-joined in their original
    document order, and stored in ``raw_data[key]["context"]``.

    Args:
        key: key of the example inside ``raw_data``.
        raw_data: dict of examples. ``raw_data[key]`` is mutated (gains "context").
            The question is expected to already be a string (see ``preprocess``).
        top_k: number of sentences to keep (default 5).
    """
    sent_list = get_sent_list(raw_data[key]["sentences"])
    if not sent_list:
        # Nothing to score: an empty context, same as the per-sentence loop gave.
        raw_data[key]["context"] = ""
        return

    question = raw_data[key]["question"]
    # One batched call instead of one call per sentence: the question is
    # repeated so it is paired with every candidate sentence.
    scores = bertscore.compute(
        predictions=[question] * len(sent_list),
        references=sent_list,
        lang='en',
    )
    precision = np.asarray(scores["precision"], dtype=float).ravel()
    best = np.argsort(-precision)[:top_k]
    # sorted() restores document order so the context reads naturally.
    raw_data[key]["context"] = " ".join(sent_list[i] for i in sorted(best))

def preprocess(raw_data):
    """Normalize questions and attach fuzzy contexts to every example, in place.

    For each example, the tokenized question is joined into a string with
    ``get_sent_str``, and then ``set_fuzzy_context`` fills ``"context"``.
    Questions that are already strings are left as they are. Joining a string
    again would space out its characters, so this guard makes re-running the
    function on processed data safe.
    """
    for key in tqdm(list(raw_data)):
        question = raw_data[key]["question"]
        if not isinstance(question, str):
            raw_data[key]["question"] = get_sent_str(question)
        set_fuzzy_context(key, raw_data)


# Preprocessing scores every sentence with BERTScore, so the processed splits
# are cached as pickled .npy dicts. The cache is reused only when ALL three
# files exist; otherwise every split is rebuilt from the raw data.
CACHE_FILES = {
    "train": "train_processed.npy",
    "dev": "dev_processed.npy",
    "test": "test_processed.npy",
}

if all(os.path.exists(path) for path in CACHE_FILES.values()):
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load cache
    # files this notebook wrote itself.
    train_raw = np.load(CACHE_FILES["train"], allow_pickle=True).item()
    dev_raw = np.load(CACHE_FILES["dev"], allow_pickle=True).item()
    test_raw = np.load(CACHE_FILES["test"], allow_pickle=True).item()
else:
    TOEFL_PATH = "./data/TOEFL-QA/data/"
    train_raw, dev_raw, test_raw = tuple(util.load_data(TOEFL_PATH))
    # Each split is written to its cache file right after it is preprocessed.
    for split_name, split_data in (("dev", dev_raw), ("test", test_raw), ("train", train_raw)):
        preprocess(split_data)
        np.save(CACHE_FILES[split_name], split_data)

In [15]:
def get_sentences_len(key):
    """Total number of word tokens across every sentence of ``train_raw[key]``.

    This counts the dataset's pre-split word tokens, not T5 subword tokens, so it
    typically understates the length the T5 tokenizer will produce.
    """
    return sum(len(sentence) for sentence in train_raw[key]['sentences'])

get_sentences_len('tpo_1-conversation_1_2')

504

### Problem
The contexts for both the RACE and TOEFL texts (the passages used as question contexts) are not guaranteed to fit within the 512 tokens our T5-base model expects. Both also include a lot of information unrelated to the question, unlike the short, focused contexts normally expected in previous question-generation (QG) work (e.g. SQuAD models).
### Possible solutions
- Annotate.
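- Use a metric (BERTScore) to find the sentences most semantically similar to a given answer, and take the top n as a "fuzzy" context. (The current `set_fuzzy_context` scores sentences against the question instead, with n = 5.)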

In [16]:
class TOEFLDataset(Dataset):
    """Map-style dataset yielding ``(context_encoding, question_encoding)`` pairs.

    ``data_dict`` maps example keys to rows with string ``'context'`` and
    ``'question'`` fields (as produced by ``preprocess``). Each item tokenizes
    both fields with the module-level ``tokenizer``, truncating to
    ``SEQ_LENGTH`` tokens, and moves the encodings to ``device``.

    Notes:
        * ``padding=True`` on a single string pads nothing, so item lengths
          vary. The default ``DataLoader`` collate therefore only works with
          ``batch_size=1``; larger batches need a padding ``collate_fn``.
        * Tensors are moved to ``device`` inside ``__getitem__``, so keep
          ``num_workers=0`` (CUDA cannot be initialised in forked workers).
    """

    def __init__(self, data_dict):
        self.data = data_dict
        self.idx_map = list(data_dict.keys())

    def __len__(self):
        return len(self.idx_map)

    @staticmethod
    def _encode(text):
        """Tokenize one string and drop the size-1 batch dimension."""
        encoded = tokenizer(
            text,
            padding=True,
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )
        # squeeze(0) rather than squeeze(): a 1-token encoding (e.g. an empty
        # context, which encodes to just </s>) must stay 1-D, not become 0-d.
        encoded['input_ids'] = encoded['input_ids'].squeeze(0)
        encoded['attention_mask'] = encoded['attention_mask'].squeeze(0)
        return encoded

    def __getitem__(self, idx):
        row = self.data[self.idx_map[idx]]
        encoded_text = self._encode(row['context'])
        encoded_question = self._encode(row['question'])
        return encoded_text.to(device), encoded_question.to(device)

    
# Wrap both splits, then build the loaders. Only the training data is shuffled.
train_set = TOEFLDataset(train_raw)
dev_set = TOEFLDataset(dev_raw)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE)


In [17]:

def train(epoch, best_val_loss, log_interval=500):
    """Run one training epoch over the module-level ``train_loader``.

    Updates the module-level ``model`` with ``optimizer``. Every
    ``log_interval`` batches it prints the mean loss of those batches and
    writes a temporary checkpoint (epoch, model/optimizer state,
    ``best_val_loss``) to ``DIR/toeflqa_finetune_tmp.pth``.

    Args:
        epoch: current epoch number, used for logging and the checkpoint.
        best_val_loss: best validation loss so far, stored in the checkpoint.
        log_interval: batches between log lines / temporary checkpoints.
    """
    model.train()
    os.makedirs(DIR, exist_ok=True)
    checkpoint_path = os.path.join(DIR, "toeflqa_finetune_tmp.pth")
    total_loss = 0.
    # Wrap the loader itself (not enumerate(...)) so tqdm knows the total length.
    for batch_index, (data, target) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            labels=masked_labels
        )
        loss = output.loss  # present because labels were passed
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()
        if (batch_index + 1) % log_interval == 0:
            print(f'| epoch {epoch} | {batch_index + 1}/{len(train_loader)} batches | loss {total_loss / log_interval}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': best_val_loss,
            }, checkpoint_path)
            total_loss = 0.
def valid(epoch):
    """Return the mean per-batch validation loss over the module-level ``dev_loader``.

    Evaluates the module-level ``model`` in eval mode. ``epoch`` is accepted for
    symmetry with ``train`` but is not used.
    """
    model.eval()
    total_loss = 0.
    # Evaluation needs no gradients. Without no_grad every forward pass would
    # build an autograd graph and keep its activations alive.
    with torch.no_grad():
        for data, target in tqdm(dev_loader):
            masked_labels = mask_label_padding(target['input_ids'])
            output = model(
                input_ids=data['input_ids'],
                attention_mask=data['attention_mask'],
                labels=masked_labels
            )
            total_loss += output.loss.item()
    return total_loss / len(dev_loader)
        
def mask_label_padding(labels, pad_token_id=None, mask_id=-100):
    """Return a copy of ``labels`` with padding ids replaced by ``mask_id``.

    Label positions equal to -100 are ignored by the seq2seq loss
    (CrossEntropyLoss ``ignore_index``), so padded positions do not contribute.

    Args:
        labels: integer tensor of target token ids. It is not modified.
        pad_token_id: id to mask. Defaults to the module-level tokenizer's pad id.
        mask_id: replacement value (default -100).

    Returns:
        A new tensor with the same shape and dtype as ``labels``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    return labels.masked_fill(labels == pad_token_id, mask_id)


In [18]:
# Load the pretrained checkpoint together with its own config. Calling
# T5ForConditionalGeneration(config).from_pretrained(...) instead would first
# build a throw-away randomly initialised model, and from_pretrained (a
# classmethod) would ignore that config entirely.
# T5 starts decoding from the pad token, so decoder_start_token_id = pad id.
config = T5Config.from_pretrained(
    PRETRAINED_MODEL, decoder_start_token_id=tokenizer.pad_token_id
)
model = T5ForConditionalGeneration.from_pretrained(PRETRAINED_MODEL, config=config)
# Match the embedding matrix to the tokenizer, which gained the '<answer>' and
# '<context>' special tokens. NOTE(review): t5-base's embedding matrix is
# believed to be larger than its tokenizer vocab, so this may shrink it;
# harmless as long as every id < len(tokenizer) -- confirm.
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
# NOTE(review): plain SGD at lr=1e-3 converges slowly for T5 fine-tuning
# (AdamW/Adafactor are the usual choices); left unchanged here.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [19]:
best_val_loss = float("inf")
# Create the checkpoint directory up front so saving cannot fail after an epoch.
os.makedirs(DIR, exist_ok=True)

for epoch in range(1, EPOCHS + 1):

    train(epoch, best_val_loss)
    torch.cuda.empty_cache()
    val_loss = valid(epoch)  # valid() takes the epoch number, not the model
    torch.cuda.empty_cache()
    print(f'\nend of epoch {epoch}\n valid loss: {val_loss}\n')

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_val_loss,
        }, os.path.join(DIR, "toeflqa_finetune.pth"))
        print("Model saved.\n")


0it [00:00, ?it/s]

| epoch 1 | 499/717 batches | loss 4.510052051305771


0it [00:00, ?it/s]


end of epoch 1
 valid loss: 4.156029810828548

Model saved.



0it [00:00, ?it/s]

| epoch 2 | 499/717 batches | loss 4.128188771009445


0it [00:00, ?it/s]


end of epoch 2
 valid loss: 3.9716839607684844

Model saved.



0it [00:00, ?it/s]

| epoch 3 | 499/717 batches | loss 3.9635279214382173


0it [00:00, ?it/s]


end of epoch 3
 valid loss: 3.8432628425859634

Model saved.



0it [00:00, ?it/s]

| epoch 4 | 499/717 batches | loss 3.8591274704933167


0it [00:00, ?it/s]


end of epoch 4
 valid loss: 3.7437118926355915

Model saved.



0it [00:00, ?it/s]

| epoch 5 | 499/717 batches | loss 3.766662883520126


0it [00:00, ?it/s]


end of epoch 5
 valid loss: 3.664342273627558

Model saved.



0it [00:00, ?it/s]

| epoch 6 | 499/717 batches | loss 3.737801780939102


0it [00:00, ?it/s]


end of epoch 6
 valid loss: 3.5916650160666435

Model saved.



0it [00:00, ?it/s]

| epoch 7 | 499/717 batches | loss 3.6807423946857454


0it [00:00, ?it/s]


end of epoch 7
 valid loss: 3.5304340229880427

Model saved.



0it [00:00, ?it/s]

| epoch 8 | 499/717 batches | loss 3.6520659070014956


0it [00:00, ?it/s]


end of epoch 8
 valid loss: 3.462945685271294

Model saved.



0it [00:00, ?it/s]

| epoch 9 | 499/717 batches | loss 3.5667401208877565


0it [00:00, ?it/s]


end of epoch 9
 valid loss: 3.401242256164551

Model saved.



0it [00:00, ?it/s]

| epoch 10 | 499/717 batches | loss 3.4863510549068453


0it [00:00, ?it/s]


end of epoch 10
 valid loss: 3.3421378664432035

Model saved.



0it [00:00, ?it/s]

| epoch 11 | 499/717 batches | loss 3.442638862133026


0it [00:00, ?it/s]


end of epoch 11
 valid loss: 3.280492793167791

Model saved.



0it [00:00, ?it/s]

| epoch 12 | 499/717 batches | loss 3.3762603905200956


0it [00:00, ?it/s]


end of epoch 12
 valid loss: 3.226461298042728

Model saved.



0it [00:00, ?it/s]

| epoch 13 | 499/717 batches | loss 3.2898035988807677


0it [00:00, ?it/s]


end of epoch 13
 valid loss: 3.1711697049679293

Model saved.



0it [00:00, ?it/s]

| epoch 14 | 499/717 batches | loss 3.2849873225688935


0it [00:00, ?it/s]


end of epoch 14
 valid loss: 3.1156374869808072

Model saved.



0it [00:00, ?it/s]

| epoch 15 | 499/717 batches | loss 3.2919903218746187


0it [00:00, ?it/s]


end of epoch 15
 valid loss: 3.0605943452927376

Model saved.



0it [00:00, ?it/s]

| epoch 16 | 499/717 batches | loss 3.171148963689804


0it [00:00, ?it/s]


end of epoch 16
 valid loss: 3.015739280369974

Model saved.



0it [00:00, ?it/s]

| epoch 17 | 499/717 batches | loss 3.144950766801834


0it [00:00, ?it/s]


end of epoch 17
 valid loss: 2.9708076773151273

Model saved.



0it [00:00, ?it/s]

| epoch 18 | 499/717 batches | loss 3.114379122257233


0it [00:00, ?it/s]


end of epoch 18
 valid loss: 2.9336923793438943

Model saved.



0it [00:00, ?it/s]

| epoch 19 | 499/717 batches | loss 3.0259915845394136


0it [00:00, ?it/s]


end of epoch 19
 valid loss: 2.8901170021103275

Model saved.



0it [00:00, ?it/s]

| epoch 20 | 499/717 batches | loss 3.1104607380628586


0it [00:00, ?it/s]


end of epoch 20
 valid loss: 2.85681307796509

Model saved.



0it [00:00, ?it/s]

| epoch 21 | 499/717 batches | loss 2.99414990735054


0it [00:00, ?it/s]


end of epoch 21
 valid loss: 2.8106576242754535

Model saved.



0it [00:00, ?it/s]

| epoch 22 | 499/717 batches | loss 2.9935269548892975


0it [00:00, ?it/s]


end of epoch 22
 valid loss: 2.781324343335244

Model saved.



0it [00:00, ?it/s]

| epoch 23 | 499/717 batches | loss 2.942680755019188


0it [00:00, ?it/s]


end of epoch 23
 valid loss: 2.751000165939331

Model saved.



0it [00:00, ?it/s]

| epoch 24 | 499/717 batches | loss 2.89762597489357


0it [00:00, ?it/s]


end of epoch 24
 valid loss: 2.7257120936147627

Model saved.



0it [00:00, ?it/s]

| epoch 25 | 499/717 batches | loss 2.91541782104969


0it [00:00, ?it/s]


end of epoch 25
 valid loss: 2.692049133200799

Model saved.



0it [00:00, ?it/s]

| epoch 26 | 499/717 batches | loss 2.84871743619442


0it [00:00, ?it/s]


end of epoch 26
 valid loss: 2.6612553932974414

Model saved.



0it [00:00, ?it/s]

| epoch 27 | 499/717 batches | loss 2.8258685202598572


0it [00:00, ?it/s]


end of epoch 27
 valid loss: 2.6478444030207973

Model saved.



0it [00:00, ?it/s]

| epoch 28 | 499/717 batches | loss 2.754151295185089


0it [00:00, ?it/s]


end of epoch 28
 valid loss: 2.633017236186612

Model saved.



0it [00:00, ?it/s]

| epoch 29 | 499/717 batches | loss 2.7920896768569947


0it [00:00, ?it/s]


end of epoch 29
 valid loss: 2.5938494340065987

Model saved.



0it [00:00, ?it/s]

| epoch 30 | 499/717 batches | loss 2.7770748398303984


0it [00:00, ?it/s]


end of epoch 30
 valid loss: 2.5711588301966266

Model saved.



0it [00:00, ?it/s]

| epoch 31 | 499/717 batches | loss 2.736102422595024


0it [00:00, ?it/s]


end of epoch 31
 valid loss: 2.5530131959146067

Model saved.



0it [00:00, ?it/s]

| epoch 32 | 499/717 batches | loss 2.7105609934329986


0it [00:00, ?it/s]


end of epoch 32
 valid loss: 2.528872140953618

Model saved.



0it [00:00, ?it/s]

| epoch 33 | 499/717 batches | loss 2.6587178170681


0it [00:00, ?it/s]


end of epoch 33
 valid loss: 2.509306880735582

Model saved.



0it [00:00, ?it/s]

| epoch 34 | 499/717 batches | loss 2.697130980849266


0it [00:00, ?it/s]


end of epoch 34
 valid loss: 2.4881860964721247

Model saved.



0it [00:00, ?it/s]

| epoch 35 | 499/717 batches | loss 2.628510082960129


0it [00:00, ?it/s]


end of epoch 35
 valid loss: 2.4703883538323064

Model saved.



0it [00:00, ?it/s]

| epoch 36 | 499/717 batches | loss 2.6377760623693467


0it [00:00, ?it/s]


end of epoch 36
 valid loss: 2.4617321097081706

Model saved.



0it [00:00, ?it/s]

| epoch 37 | 499/717 batches | loss 2.6295458718538285


0it [00:00, ?it/s]


end of epoch 37
 valid loss: 2.439993429568506

Model saved.



0it [00:00, ?it/s]

| epoch 38 | 499/717 batches | loss 2.6018658777475356


0it [00:00, ?it/s]


end of epoch 38
 valid loss: 2.4233730916054017

Model saved.



0it [00:00, ?it/s]

| epoch 39 | 499/717 batches | loss 2.5818732279539107


0it [00:00, ?it/s]


end of epoch 39
 valid loss: 2.4035555479987973

Model saved.



0it [00:00, ?it/s]

| epoch 40 | 499/717 batches | loss 2.5533215779066087


0it [00:00, ?it/s]


end of epoch 40
 valid loss: 2.3978344668303766

Model saved.



0it [00:00, ?it/s]

| epoch 41 | 499/717 batches | loss 2.534352411746979


0it [00:00, ?it/s]


end of epoch 41
 valid loss: 2.3811501701993327

Model saved.



0it [00:00, ?it/s]

| epoch 42 | 499/717 batches | loss 2.5203198694586755


0it [00:00, ?it/s]


end of epoch 42
 valid loss: 2.366325792285704

Model saved.



0it [00:00, ?it/s]

| epoch 43 | 499/717 batches | loss 2.567799260735512


0it [00:00, ?it/s]


end of epoch 43
 valid loss: 2.3489544261847772

Model saved.



0it [00:00, ?it/s]

| epoch 44 | 499/717 batches | loss 2.5229222401380538


0it [00:00, ?it/s]


end of epoch 44
 valid loss: 2.333231546705769

Model saved.



0it [00:00, ?it/s]

| epoch 45 | 499/717 batches | loss 2.498352764606476


0it [00:00, ?it/s]


end of epoch 45
 valid loss: 2.320905909422905

Model saved.



0it [00:00, ?it/s]

| epoch 46 | 499/717 batches | loss 2.466204426884651


0it [00:00, ?it/s]


end of epoch 46
 valid loss: 2.298273172109358

Model saved.



0it [00:00, ?it/s]

| epoch 47 | 499/717 batches | loss 2.4547005054950715


0it [00:00, ?it/s]


end of epoch 47
 valid loss: 2.286314063975888

Model saved.



0it [00:00, ?it/s]

| epoch 48 | 499/717 batches | loss 2.4377678985595703


0it [00:00, ?it/s]


end of epoch 48
 valid loss: 2.275740163941537

Model saved.



0it [00:00, ?it/s]

| epoch 49 | 499/717 batches | loss 2.4210245608091356


0it [00:00, ?it/s]


end of epoch 49
 valid loss: 2.257202104214699

Model saved.



0it [00:00, ?it/s]

| epoch 50 | 499/717 batches | loss 2.408193953514099


0it [00:00, ?it/s]


end of epoch 50
 valid loss: 2.2486771917150867

Model saved.



0it [00:00, ?it/s]

| epoch 51 | 499/717 batches | loss 2.426490233540535


0it [00:00, ?it/s]


end of epoch 51
 valid loss: 2.2426746473197015

Model saved.



0it [00:00, ?it/s]

| epoch 52 | 499/717 batches | loss 2.397951184749603


0it [00:00, ?it/s]


end of epoch 52
 valid loss: 2.227567031979561

Model saved.



0it [00:00, ?it/s]

| epoch 53 | 499/717 batches | loss 2.362931210398674


0it [00:00, ?it/s]


end of epoch 53
 valid loss: 2.21547282847666

Model saved.



0it [00:00, ?it/s]

| epoch 54 | 499/717 batches | loss 2.398617545962334


0it [00:00, ?it/s]


end of epoch 54
 valid loss: 2.2035631503789657

Model saved.



0it [00:00, ?it/s]

| epoch 55 | 499/717 batches | loss 2.382297125220299


0it [00:00, ?it/s]


end of epoch 55
 valid loss: 2.200275225985435

Model saved.



0it [00:00, ?it/s]

| epoch 56 | 499/717 batches | loss 2.369173450827599


0it [00:00, ?it/s]


end of epoch 56
 valid loss: 2.1816405105975365

Model saved.



0it [00:00, ?it/s]

| epoch 57 | 499/717 batches | loss 2.320723517894745


0it [00:00, ?it/s]


end of epoch 57
 valid loss: 2.1724525060384505

Model saved.



0it [00:00, ?it/s]

| epoch 58 | 499/717 batches | loss 2.3213274064064024


0it [00:00, ?it/s]


end of epoch 58
 valid loss: 2.1642645376343883

Model saved.



0it [00:00, ?it/s]

| epoch 59 | 499/717 batches | loss 2.280716080069542


0it [00:00, ?it/s]


end of epoch 59
 valid loss: 2.154710277915001

Model saved.



0it [00:00, ?it/s]

| epoch 60 | 499/717 batches | loss 2.2610679390430453


0it [00:00, ?it/s]


end of epoch 60
 valid loss: 2.138645040412103

Model saved.



0it [00:00, ?it/s]

| epoch 61 | 499/717 batches | loss 2.3014080353975297


0it [00:00, ?it/s]


end of epoch 61
 valid loss: 2.1308912702145113

Model saved.



0it [00:00, ?it/s]

| epoch 62 | 499/717 batches | loss 2.3098448647260668


0it [00:00, ?it/s]


end of epoch 62
 valid loss: 2.117645217045661

Model saved.



0it [00:00, ?it/s]

| epoch 63 | 499/717 batches | loss 2.260449944615364


0it [00:00, ?it/s]


end of epoch 63
 valid loss: 2.1109506564755596

Model saved.



0it [00:00, ?it/s]

| epoch 64 | 499/717 batches | loss 2.2939429384469987


0it [00:00, ?it/s]


end of epoch 64
 valid loss: 2.1076511245581413

Model saved.



0it [00:00, ?it/s]

| epoch 65 | 499/717 batches | loss 2.2627544811964033


0it [00:00, ?it/s]


end of epoch 65
 valid loss: 2.0960895246075046

Model saved.



0it [00:00, ?it/s]

| epoch 66 | 499/717 batches | loss 2.271068125426769


0it [00:00, ?it/s]


end of epoch 66
 valid loss: 2.088312638382758

Model saved.



0it [00:00, ?it/s]

| epoch 67 | 499/717 batches | loss 2.2121461458206175


0it [00:00, ?it/s]


end of epoch 67
 valid loss: 2.0770405872214224

Model saved.



0it [00:00, ?it/s]

| epoch 68 | 499/717 batches | loss 2.213873530745506


0it [00:00, ?it/s]


end of epoch 68
 valid loss: 2.066656264566606

Model saved.



0it [00:00, ?it/s]

| epoch 69 | 499/717 batches | loss 2.2264406936168673


0it [00:00, ?it/s]


end of epoch 69
 valid loss: 2.0637475208890055

Model saved.



0it [00:00, ?it/s]

| epoch 70 | 499/717 batches | loss 2.1913715463876726


0it [00:00, ?it/s]


end of epoch 70
 valid loss: 2.050509111054482

Model saved.



0it [00:00, ?it/s]

| epoch 71 | 499/717 batches | loss 2.1843788386583327


0it [00:00, ?it/s]


end of epoch 71
 valid loss: 2.046802490468948

Model saved.



0it [00:00, ?it/s]

| epoch 72 | 499/717 batches | loss 2.1865647956132888


0it [00:00, ?it/s]


end of epoch 72
 valid loss: 2.034256202078635

Model saved.



0it [00:00, ?it/s]

| epoch 73 | 499/717 batches | loss 2.1815056593418123


0it [00:00, ?it/s]


end of epoch 73
 valid loss: 2.0300684229981516

Model saved.



0it [00:00, ?it/s]

| epoch 74 | 499/717 batches | loss 2.146151007294655


0it [00:00, ?it/s]


end of epoch 74
 valid loss: 2.0202117538259876

Model saved.



0it [00:00, ?it/s]

| epoch 75 | 499/717 batches | loss 2.172693804383278


0it [00:00, ?it/s]


end of epoch 75
 valid loss: 2.0131233517200715

Model saved.



0it [00:00, ?it/s]

| epoch 76 | 499/717 batches | loss 2.167017232596874


0it [00:00, ?it/s]


end of epoch 76
 valid loss: 2.0049516472124282

Model saved.



0it [00:00, ?it/s]

| epoch 77 | 499/717 batches | loss 2.1347506455779075


0it [00:00, ?it/s]


end of epoch 77
 valid loss: 1.997099358227945

Model saved.



0it [00:00, ?it/s]

| epoch 78 | 499/717 batches | loss 2.1293803831338884


0it [00:00, ?it/s]


end of epoch 78
 valid loss: 1.9941915864906004

Model saved.



0it [00:00, ?it/s]

| epoch 79 | 499/717 batches | loss 2.1335576934814453


0it [00:00, ?it/s]


end of epoch 79
 valid loss: 1.9890037467402797

Model saved.



0it [00:00, ?it/s]

| epoch 80 | 499/717 batches | loss 2.1311617546081543


0it [00:00, ?it/s]


end of epoch 80
 valid loss: 1.9832665963519005

Model saved.



0it [00:00, ?it/s]

| epoch 81 | 499/717 batches | loss 2.111201849400997


0it [00:00, ?it/s]


end of epoch 81
 valid loss: 1.974227356814569

Model saved.



0it [00:00, ?it/s]

| epoch 82 | 499/717 batches | loss 2.1227848477959634


0it [00:00, ?it/s]


end of epoch 82
 valid loss: 1.968700152731711

Model saved.



0it [00:00, ?it/s]

| epoch 83 | 499/717 batches | loss 2.1066171692609785


0it [00:00, ?it/s]


end of epoch 83
 valid loss: 1.9650618030178932

Model saved.



0it [00:00, ?it/s]

| epoch 84 | 499/717 batches | loss 2.1003383309841155


0it [00:00, ?it/s]


end of epoch 84
 valid loss: 1.9568796244359785

Model saved.



0it [00:00, ?it/s]

| epoch 85 | 499/717 batches | loss 2.052569935441017


0it [00:00, ?it/s]


end of epoch 85
 valid loss: 1.948140273651769

Model saved.



0it [00:00, ?it/s]

| epoch 86 | 499/717 batches | loss 2.1079785068035126


0it [00:00, ?it/s]


end of epoch 86
 valid loss: 1.9434882661988657

Model saved.



0it [00:00, ?it/s]

| epoch 87 | 499/717 batches | loss 2.0544270939826967


0it [00:00, ?it/s]


end of epoch 87
 valid loss: 1.9399458434312575

Model saved.



0it [00:00, ?it/s]

| epoch 88 | 499/717 batches | loss 2.074507597506046


0it [00:00, ?it/s]


end of epoch 88
 valid loss: 1.938493577703353

Model saved.



0it [00:00, ?it/s]

| epoch 89 | 499/717 batches | loss 2.099944856464863


0it [00:00, ?it/s]


end of epoch 89
 valid loss: 1.9326449565349086

Model saved.



0it [00:00, ?it/s]

| epoch 90 | 499/717 batches | loss 2.041440788984299


0it [00:00, ?it/s]


end of epoch 90
 valid loss: 1.9232668237340065

Model saved.



0it [00:00, ?it/s]

| epoch 91 | 499/717 batches | loss 2.0180318384170532


0it [00:00, ?it/s]


end of epoch 91
 valid loss: 1.9201313540820153

Model saved.



0it [00:00, ?it/s]

| epoch 92 | 499/717 batches | loss 2.035712058007717


0it [00:00, ?it/s]


end of epoch 92
 valid loss: 1.9136758451500246

Model saved.



0it [00:00, ?it/s]

| epoch 93 | 499/717 batches | loss 2.0001035706996917


0it [00:00, ?it/s]


end of epoch 93
 valid loss: 1.9106699002365912

Model saved.



0it [00:00, ?it/s]

| epoch 94 | 499/717 batches | loss 2.0119805621504785


0it [00:00, ?it/s]


end of epoch 94
 valid loss: 1.9050968822933012

Model saved.



0it [00:00, ?it/s]

| epoch 95 | 499/717 batches | loss 2.024161042034626


0it [00:00, ?it/s]


end of epoch 95
 valid loss: 1.8982336199091328

Model saved.



0it [00:00, ?it/s]

| epoch 96 | 499/717 batches | loss 2.0116395494937898


0it [00:00, ?it/s]


end of epoch 96
 valid loss: 1.89553823922911

Model saved.



0it [00:00, ?it/s]

| epoch 97 | 499/717 batches | loss 1.9994231721162796


0it [00:00, ?it/s]


end of epoch 97
 valid loss: 1.8952989929145383

Model saved.



0it [00:00, ?it/s]

| epoch 98 | 499/717 batches | loss 2.016929509520531


0it [00:00, ?it/s]


end of epoch 98
 valid loss: 1.8891084290319873

Model saved.



0it [00:00, ?it/s]

| epoch 99 | 499/717 batches | loss 1.9773242090940475


0it [00:00, ?it/s]


end of epoch 99
 valid loss: 1.881124675033554

Model saved.



0it [00:00, ?it/s]

| epoch 100 | 499/717 batches | loss 1.975551368713379


0it [00:00, ?it/s]


end of epoch 100
 valid loss: 1.8761680460264605

Model saved.



0it [00:00, ?it/s]

| epoch 101 | 499/717 batches | loss 2.010136302232742


0it [00:00, ?it/s]


end of epoch 101
 valid loss: 1.8758336916085212

Model saved.



0it [00:00, ?it/s]

| epoch 102 | 499/717 batches | loss 1.9704259198904037


0it [00:00, ?it/s]


end of epoch 102
 valid loss: 1.8752174836493307

Model saved.



0it [00:00, ?it/s]

| epoch 103 | 499/717 batches | loss 1.9936713396310806


0it [00:00, ?it/s]


end of epoch 103
 valid loss: 1.868992734099588

Model saved.



0it [00:00, ?it/s]

| epoch 104 | 499/717 batches | loss 1.968431445658207


0it [00:00, ?it/s]


end of epoch 104
 valid loss: 1.8655345935013987

Model saved.



0it [00:00, ?it/s]

| epoch 105 | 499/717 batches | loss 1.9975384923815727


0it [00:00, ?it/s]


end of epoch 105
 valid loss: 1.8603154541023317

Model saved.



0it [00:00, ?it/s]

| epoch 106 | 499/717 batches | loss 1.9805788516402245


0it [00:00, ?it/s]


end of epoch 106
 valid loss: 1.8523856278869413

Model saved.



0it [00:00, ?it/s]

| epoch 107 | 499/717 batches | loss 1.9714271748661996


0it [00:00, ?it/s]


end of epoch 107
 valid loss: 1.8499299715603552

Model saved.



0it [00:00, ?it/s]

| epoch 108 | 499/717 batches | loss 1.9794098434448242


0it [00:00, ?it/s]


end of epoch 108
 valid loss: 1.8465529711496445

Model saved.



0it [00:00, ?it/s]

| epoch 109 | 499/717 batches | loss 1.9519745312929153


0it [00:00, ?it/s]


end of epoch 109
 valid loss: 1.8466672587298578



0it [00:00, ?it/s]

| epoch 110 | 499/717 batches | loss 1.9145823240876199


0it [00:00, ?it/s]


end of epoch 110
 valid loss: 1.8410794148522038

Model saved.



0it [00:00, ?it/s]

| epoch 111 | 499/717 batches | loss 1.9387487552165985


0it [00:00, ?it/s]


end of epoch 111
 valid loss: 1.83778300184396

Model saved.



0it [00:00, ?it/s]

| epoch 112 | 499/717 batches | loss 1.9506376259326934


0it [00:00, ?it/s]


end of epoch 112
 valid loss: 1.8330545218721512

Model saved.



0it [00:00, ?it/s]

| epoch 113 | 499/717 batches | loss 1.9334290142655373


0it [00:00, ?it/s]


end of epoch 113
 valid loss: 1.8339322170903605



0it [00:00, ?it/s]

| epoch 114 | 499/717 batches | loss 1.9708645910024643


0it [00:00, ?it/s]


end of epoch 114
 valid loss: 1.8266538588270065

Model saved.



0it [00:00, ?it/s]

| epoch 115 | 499/717 batches | loss 1.932292530953884


0it [00:00, ?it/s]


end of epoch 115
 valid loss: 1.8242488647660902

Model saved.



0it [00:00, ?it/s]

| epoch 116 | 499/717 batches | loss 1.9345431103110313


0it [00:00, ?it/s]


end of epoch 116
 valid loss: 1.821446467551493

Model saved.



0it [00:00, ?it/s]

| epoch 117 | 499/717 batches | loss 1.927534532904625


0it [00:00, ?it/s]


end of epoch 117
 valid loss: 1.8176838359044445

Model saved.



0it [00:00, ?it/s]

| epoch 118 | 499/717 batches | loss 1.9299883553385735


0it [00:00, ?it/s]


end of epoch 118
 valid loss: 1.8134953966063838

Model saved.



0it [00:00, ?it/s]

| epoch 119 | 499/717 batches | loss 1.8983493846654893


0it [00:00, ?it/s]


end of epoch 119
 valid loss: 1.8102356277165874

Model saved.



0it [00:00, ?it/s]

| epoch 120 | 499/717 batches | loss 1.9063073742985726


0it [00:00, ?it/s]


end of epoch 120
 valid loss: 1.808807912613115

Model saved.



0it [00:00, ?it/s]

| epoch 121 | 499/717 batches | loss 1.9101255034208298


0it [00:00, ?it/s]


end of epoch 121
 valid loss: 1.8069795093709422

Model saved.



0it [00:00, ?it/s]

| epoch 122 | 499/717 batches | loss 1.9375663681030273


0it [00:00, ?it/s]


end of epoch 122
 valid loss: 1.8013109653707473

Model saved.



0it [00:00, ?it/s]

| epoch 123 | 499/717 batches | loss 1.88348724347353


0it [00:00, ?it/s]


end of epoch 123
 valid loss: 1.7965130995838874

Model saved.



0it [00:00, ?it/s]

| epoch 124 | 499/717 batches | loss 1.8954634659290315


0it [00:00, ?it/s]


end of epoch 124
 valid loss: 1.7944134502641615

Model saved.



0it [00:00, ?it/s]

| epoch 125 | 499/717 batches | loss 1.8642473329901694


0it [00:00, ?it/s]


end of epoch 125
 valid loss: 1.7931133186624897

Model saved.



0it [00:00, ?it/s]

| epoch 126 | 499/717 batches | loss 1.8808505790233612


0it [00:00, ?it/s]


end of epoch 126
 valid loss: 1.7917315205258708

Model saved.



0it [00:00, ?it/s]

| epoch 127 | 499/717 batches | loss 1.847413781285286


0it [00:00, ?it/s]


end of epoch 127
 valid loss: 1.7906737947656262

Model saved.



0it [00:00, ?it/s]

| epoch 128 | 499/717 batches | loss 1.8813583680987358


0it [00:00, ?it/s]


end of epoch 128
 valid loss: 1.787565374807004

Model saved.



0it [00:00, ?it/s]

| epoch 129 | 499/717 batches | loss 1.9109870626330376


0it [00:00, ?it/s]


end of epoch 129
 valid loss: 1.7811721767629347

Model saved.



0it [00:00, ?it/s]

| epoch 130 | 499/717 batches | loss 1.8723261020183564


0it [00:00, ?it/s]


end of epoch 130
 valid loss: 1.7796466934584803

Model saved.



0it [00:00, ?it/s]

| epoch 131 | 499/717 batches | loss 1.876574683010578


0it [00:00, ?it/s]


end of epoch 131
 valid loss: 1.7776899820854586

Model saved.



0it [00:00, ?it/s]

| epoch 132 | 499/717 batches | loss 1.894097760617733


0it [00:00, ?it/s]


end of epoch 132
 valid loss: 1.7752366597133298

Model saved.



0it [00:00, ?it/s]

| epoch 133 | 499/717 batches | loss 1.857716426730156


0it [00:00, ?it/s]


end of epoch 133
 valid loss: 1.7723477614983436

Model saved.



0it [00:00, ?it/s]

| epoch 134 | 499/717 batches | loss 1.8391487130522728


0it [00:00, ?it/s]


end of epoch 134
 valid loss: 1.7714327987163299

Model saved.



0it [00:00, ?it/s]

| epoch 135 | 499/717 batches | loss 1.8307079576849938


0it [00:00, ?it/s]


end of epoch 135
 valid loss: 1.7688495175492378

Model saved.



0it [00:00, ?it/s]

| epoch 136 | 499/717 batches | loss 1.8473205199241638


0it [00:00, ?it/s]


end of epoch 136
 valid loss: 1.7665520519498856

Model saved.



0it [00:00, ?it/s]

| epoch 137 | 499/717 batches | loss 1.8766130974292756


0it [00:00, ?it/s]


end of epoch 137
 valid loss: 1.7625419757058542

Model saved.



0it [00:00, ?it/s]

| epoch 138 | 499/717 batches | loss 1.8037163416743278


0it [00:00, ?it/s]


end of epoch 138
 valid loss: 1.7632595000247802



0it [00:00, ?it/s]

| epoch 139 | 499/717 batches | loss 1.8485993456840515


0it [00:00, ?it/s]


end of epoch 139
 valid loss: 1.758153741878848

Model saved.



0it [00:00, ?it/s]

| epoch 140 | 499/717 batches | loss 1.8709124882221222


0it [00:00, ?it/s]


end of epoch 140
 valid loss: 1.758534713618217



0it [00:00, ?it/s]

| epoch 141 | 499/717 batches | loss 1.8438328152000905


0it [00:00, ?it/s]


end of epoch 141
 valid loss: 1.7569647873601606

Model saved.



0it [00:00, ?it/s]

| epoch 142 | 499/717 batches | loss 1.8321500668823718


0it [00:00, ?it/s]


end of epoch 142
 valid loss: 1.756455632707765

Model saved.



0it [00:00, ?it/s]

| epoch 143 | 499/717 batches | loss 1.8003686974048614


0it [00:00, ?it/s]


end of epoch 143
 valid loss: 1.7534471259963127

Model saved.



0it [00:00, ?it/s]

| epoch 144 | 499/717 batches | loss 1.8309126689434052


0it [00:00, ?it/s]


end of epoch 144
 valid loss: 1.7506717001238177

Model saved.



0it [00:00, ?it/s]

| epoch 145 | 499/717 batches | loss 1.8440135110616684


0it [00:00, ?it/s]


end of epoch 145
 valid loss: 1.7483322464169995

Model saved.



0it [00:00, ?it/s]

| epoch 146 | 499/717 batches | loss 1.84939644241333


0it [00:00, ?it/s]


end of epoch 146
 valid loss: 1.7457296168131213

Model saved.



0it [00:00, ?it/s]

| epoch 147 | 499/717 batches | loss 1.8584481683969498


0it [00:00, ?it/s]


end of epoch 147
 valid loss: 1.7450581144902013

Model saved.



0it [00:00, ?it/s]

| epoch 148 | 499/717 batches | loss 1.8360399552583695


0it [00:00, ?it/s]


end of epoch 148
 valid loss: 1.7443848014358552

Model saved.



0it [00:00, ?it/s]

| epoch 149 | 499/717 batches | loss 1.8301794874370099


0it [00:00, ?it/s]


end of epoch 149
 valid loss: 1.7402041278058482

Model saved.



0it [00:00, ?it/s]

| epoch 150 | 499/717 batches | loss 1.8009497335255147


0it [00:00, ?it/s]


end of epoch 150
 valid loss: 1.740283273400799



0it [00:00, ?it/s]

| epoch 151 | 499/717 batches | loss 1.7896383545994758


0it [00:00, ?it/s]


end of epoch 151
 valid loss: 1.7369464587780736

Model saved.



0it [00:00, ?it/s]

| epoch 152 | 499/717 batches | loss 1.786016196012497


0it [00:00, ?it/s]


end of epoch 152
 valid loss: 1.7351312731062212

Model saved.



0it [00:00, ?it/s]

| epoch 153 | 499/717 batches | loss 1.7839507500827312


0it [00:00, ?it/s]


end of epoch 153
 valid loss: 1.735696551059523



0it [00:00, ?it/s]

| epoch 154 | 499/717 batches | loss 1.785247866511345


0it [00:00, ?it/s]


end of epoch 154
 valid loss: 1.7334239542964966

Model saved.



0it [00:00, ?it/s]

| epoch 155 | 499/717 batches | loss 1.7869517303109168


0it [00:00, ?it/s]


end of epoch 155
 valid loss: 1.7313393245060598

Model saved.



0it [00:00, ?it/s]

| epoch 156 | 499/717 batches | loss 1.796659644126892


0it [00:00, ?it/s]


end of epoch 156
 valid loss: 1.7295978351706458

Model saved.



0it [00:00, ?it/s]

| epoch 157 | 499/717 batches | loss 1.7778088682591915


0it [00:00, ?it/s]


end of epoch 157
 valid loss: 1.7278394580127732

Model saved.



0it [00:00, ?it/s]

| epoch 158 | 499/717 batches | loss 1.8026908577680587


0it [00:00, ?it/s]


end of epoch 158
 valid loss: 1.7264209235868146

Model saved.



0it [00:00, ?it/s]

| epoch 159 | 499/717 batches | loss 1.7546847237050534


0it [00:00, ?it/s]


end of epoch 159
 valid loss: 1.7256776477781035

Model saved.



0it [00:00, ?it/s]

| epoch 160 | 499/717 batches | loss 1.7652593422532081


0it [00:00, ?it/s]


end of epoch 160
 valid loss: 1.7221022182174268

Model saved.



0it [00:00, ?it/s]

| epoch 161 | 499/717 batches | loss 1.7592937519848346


0it [00:00, ?it/s]


end of epoch 161
 valid loss: 1.7206647617201651

Model saved.



0it [00:00, ?it/s]

| epoch 162 | 499/717 batches | loss 1.7670279442071914


0it [00:00, ?it/s]


end of epoch 162
 valid loss: 1.7197563165137846

Model saved.



0it [00:00, ?it/s]

| epoch 163 | 499/717 batches | loss 1.7644581147432328


0it [00:00, ?it/s]


end of epoch 163
 valid loss: 1.716648340826073

Model saved.



0it [00:00, ?it/s]

| epoch 164 | 499/717 batches | loss 1.7883823050260543


0it [00:00, ?it/s]


end of epoch 164
 valid loss: 1.7157758850484126

Model saved.



0it [00:00, ?it/s]

| epoch 165 | 499/717 batches | loss 1.7440111623704433


0it [00:00, ?it/s]


end of epoch 165
 valid loss: 1.7117383671143362

Model saved.



0it [00:00, ?it/s]

| epoch 166 | 499/717 batches | loss 1.7820554079115392


0it [00:00, ?it/s]


end of epoch 166
 valid loss: 1.7100490134569906

Model saved.



0it [00:00, ?it/s]

| epoch 167 | 499/717 batches | loss 1.762842532813549


0it [00:00, ?it/s]


end of epoch 167
 valid loss: 1.7096551149602859

Model saved.



0it [00:00, ?it/s]

| epoch 168 | 499/717 batches | loss 1.7706910654008388


0it [00:00, ?it/s]


end of epoch 168
 valid loss: 1.7089474157940956

Model saved.



0it [00:00, ?it/s]

| epoch 169 | 499/717 batches | loss 1.7810946252644062


0it [00:00, ?it/s]


end of epoch 169
 valid loss: 1.7084927138301633

Model saved.



0it [00:00, ?it/s]

| epoch 170 | 499/717 batches | loss 1.7553940610587597


0it [00:00, ?it/s]


end of epoch 170
 valid loss: 1.7073908502296093

Model saved.



0it [00:00, ?it/s]

| epoch 171 | 499/717 batches | loss 1.7544411432147027


0it [00:00, ?it/s]


end of epoch 171
 valid loss: 1.7035692603357377

Model saved.



0it [00:00, ?it/s]

| epoch 172 | 499/717 batches | loss 1.7380939495265484


0it [00:00, ?it/s]


end of epoch 172
 valid loss: 1.7040312605038765



0it [00:00, ?it/s]

| epoch 173 | 499/717 batches | loss 1.7471730592250825


0it [00:00, ?it/s]


end of epoch 173
 valid loss: 1.7025822943016407

Model saved.



0it [00:00, ?it/s]

| epoch 174 | 499/717 batches | loss 1.7733131618201732


0it [00:00, ?it/s]


end of epoch 174
 valid loss: 1.7014708669195253

Model saved.



0it [00:00, ?it/s]

| epoch 175 | 499/717 batches | loss 1.7635215792953969


0it [00:00, ?it/s]


end of epoch 175
 valid loss: 1.69885925239613

Model saved.



0it [00:00, ?it/s]

| epoch 176 | 499/717 batches | loss 1.7501116095483302


0it [00:00, ?it/s]


end of epoch 176
 valid loss: 1.6972387920944922

Model saved.



0it [00:00, ?it/s]

| epoch 177 | 499/717 batches | loss 1.7474195227324962


0it [00:00, ?it/s]


end of epoch 177
 valid loss: 1.6946883541682074

Model saved.



0it [00:00, ?it/s]

| epoch 178 | 499/717 batches | loss 1.750824129909277


0it [00:00, ?it/s]


end of epoch 178
 valid loss: 1.6950084047932779



0it [00:00, ?it/s]

| epoch 179 | 499/717 batches | loss 1.7695018101930617


0it [00:00, ?it/s]


end of epoch 179
 valid loss: 1.694521088513636

Model saved.



0it [00:00, ?it/s]

| epoch 180 | 499/717 batches | loss 1.7514537365138532


0it [00:00, ?it/s]


end of epoch 180
 valid loss: 1.693586170433029

Model saved.



0it [00:00, ?it/s]

| epoch 181 | 499/717 batches | loss 1.7229784047305583


0it [00:00, ?it/s]


end of epoch 181
 valid loss: 1.69276563618933

Model saved.



0it [00:00, ?it/s]

| epoch 182 | 499/717 batches | loss 1.738832904368639


0it [00:00, ?it/s]


end of epoch 182
 valid loss: 1.691337850305342

Model saved.



0it [00:00, ?it/s]

| epoch 183 | 499/717 batches | loss 1.716129981249571


0it [00:00, ?it/s]


end of epoch 183
 valid loss: 1.6902529307671132

Model saved.



0it [00:00, ?it/s]

| epoch 184 | 499/717 batches | loss 1.7609127816855907


0it [00:00, ?it/s]


end of epoch 184
 valid loss: 1.6886789701398341

Model saved.



0it [00:00, ?it/s]

| epoch 185 | 499/717 batches | loss 1.7337189844548702


0it [00:00, ?it/s]


end of epoch 185
 valid loss: 1.6860775678388533

Model saved.



0it [00:00, ?it/s]

| epoch 186 | 499/717 batches | loss 1.7409482606053353


0it [00:00, ?it/s]


end of epoch 186
 valid loss: 1.6853929759273607

Model saved.



0it [00:00, ?it/s]

| epoch 187 | 499/717 batches | loss 1.711442506492138


0it [00:00, ?it/s]


end of epoch 187
 valid loss: 1.683725521088608

Model saved.



0it [00:00, ?it/s]

| epoch 188 | 499/717 batches | loss 1.7338746135532856


0it [00:00, ?it/s]


end of epoch 188
 valid loss: 1.6826121102898353

Model saved.



0it [00:00, ?it/s]

| epoch 189 | 499/717 batches | loss 1.7477533910870553


0it [00:00, ?it/s]


end of epoch 189
 valid loss: 1.6818893668392012

Model saved.



0it [00:00, ?it/s]

| epoch 190 | 499/717 batches | loss 1.7140388762950898


0it [00:00, ?it/s]


end of epoch 190
 valid loss: 1.682934967260207



0it [00:00, ?it/s]

| epoch 191 | 499/717 batches | loss 1.726342868000269


0it [00:00, ?it/s]


end of epoch 191
 valid loss: 1.682394253991304



0it [00:00, ?it/s]

| epoch 192 | 499/717 batches | loss 1.6802928346693515


0it [00:00, ?it/s]


end of epoch 192
 valid loss: 1.680763116286647

Model saved.



0it [00:00, ?it/s]

| epoch 193 | 499/717 batches | loss 1.733759155511856


0it [00:00, ?it/s]


end of epoch 193
 valid loss: 1.677755003134089

Model saved.



0it [00:00, ?it/s]

| epoch 194 | 499/717 batches | loss 1.713069381326437


0it [00:00, ?it/s]


end of epoch 194
 valid loss: 1.6768873600709824

Model saved.



0it [00:00, ?it/s]

| epoch 195 | 499/717 batches | loss 1.7214917239546776


0it [00:00, ?it/s]


end of epoch 195
 valid loss: 1.6761765124336365

Model saved.



0it [00:00, ?it/s]

| epoch 196 | 499/717 batches | loss 1.766743166834116


0it [00:00, ?it/s]


end of epoch 196
 valid loss: 1.6729870497459365

Model saved.



0it [00:00, ?it/s]

| epoch 197 | 499/717 batches | loss 1.7425252770483495


0it [00:00, ?it/s]


end of epoch 197
 valid loss: 1.673574005884509



0it [00:00, ?it/s]

| epoch 198 | 499/717 batches | loss 1.6937491520047188


0it [00:00, ?it/s]


end of epoch 198
 valid loss: 1.6733472582794005



0it [00:00, ?it/s]

| epoch 199 | 499/717 batches | loss 1.7220064327865838


0it [00:00, ?it/s]


end of epoch 199
 valid loss: 1.6725356840798933

Model saved.



0it [00:00, ?it/s]

| epoch 200 | 499/717 batches | loss 1.7201947879940271


0it [00:00, ?it/s]


end of epoch 200
 valid loss: 1.6712973372590156

Model saved.

