# Hardware Requirements

The overall training process takes around **5 hours** using a single P100 GPU (tested on Kaggle). <br>
On Google Colab with a T4 GPU, it will take longer, roughly **8 hours** (Not recommended since Colab might automatically lose connection)

GPU RAM usage:  14GB <br>
Disk: 10 GB

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Install Dependencies

In [2]:
!pip install numpy tqdm nltk transformers==4.31.0 peft==0.4.0 accelerate==0.21.0 bitsandbytes==0.41.1 pandas



# Download the PDF-VQA Dataset

**Note:** Although in the provided [**strong baseline juypter notebook**](https://colab.research.google.com/drive/1C4HgX_Yl3vc6u5Wq6duZe0YQBU34ooh2?usp=sharing) there is a way of downloading the files via gdrive API, we noticed that in this way the downloaded validation file "val_dataframe.csv" **DOES NOT** contain the ground-truth answer, which make it impossible to use the validation set for validating the training.
Therefore, we have to download the full dataset from kaggle instead.

Download the full dataset from the Kaggle competition data page: https://www.kaggle.com/competitions/pdfvqa/data, by clicking the button "Download All" in the bottom right. Unzip, and put the downloaded folder "**pdfvqa**" to the same folder as this notebook.

Alternatively, you can try the following command: (This might not work after certain time, since the downloading link was dynamically allocated.)

In [3]:
!curl -o pdfvqa.zip 'https://storage.googleapis.com/kaggle-competitions-data/kaggle-v2/56858/6336428/bundle/archive.zip?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1696621416&Signature=MD2fz4I6wSiWcaViYmZ31JBfS5A4Qlcsjt557vgdgLydNC5uCRwH21OcX9AzJ4AbYEd9gt9w%2FACLBd%2BUEuMd8%2FLaJ5unJWcgSbERUDLnLCr9p1SE1UCwseAbM2FnocPC3yelMTTjZ8SoMq49rZuB3Fx%2BbCqS6ddTmssDhPZQ5tOrdQD31PDOBXxbmoZxx7JQco6%2FOCMXy%2BoWaeDf9leZfOro5YfNu%2FANo%2FoUDzLQ2kPxL%2B%2F2oGMFpFmoM60zaupqBsaUJz0XIKnO0LfiT7kv6Jd%2BOJLD45iRyXkcHIUYuTCNvTKSjABqelJjDloVRPLZfV1%2BrCNc2C%2BtzaeJSH2sdg%3D%3D&response-content-disposition=attachment%3B+filename%3Dpdfvqa.zip'
!unzip -q pdfvqa.zip -d pdfvqa/

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2103M  100 2103M    0     0  26.1M      0  0:01:20  0:01:20 --:--:-- 27.3M:19  0:01:02 28.8M20  0:00:37  0:00:43 27.9M0  0:01:16  0:00:04 29.9M


Or, alternatively, use the kaggle API: ([**User authentication of Kaggle API is needed**](https://www.kaggle.com/discussions/general/74235))

In [4]:
# !kaggle competitions download -c pdfvqa

### Here is the structure of the data folder
```
pdfvqa
├── test_dataframe_without_answer.csv
├── test_doc_info.pkl
├── test_images
├── train_dataframe.csv
├── train_doc_info.pkl
├── train_images
├── val_dataframe.csv
├── val_doc_info.pkl
└── val_images
```

# Code for the Model

## Utils

In [5]:
import os
from glob import glob
import torch
import torch.nn as nn
import json
from tqdm import tqdm

def load_model( model_folder ):
    ckpt_list =  glob( model_folder + "/*.pt" )
    if len( ckpt_list ) >0:
        ckpt_list.sort( key = os.path.getmtime )
        ckpt_name = ckpt_list[-1]
        if torch.cuda.is_available():
            ckpt = torch.load( ckpt_name )
        else:
            ckpt = torch.load( ckpt_name,  map_location=torch.device('cpu') )
    else:
        ckpt = None
    return ckpt

def save_model(  module_dicts ,save_name , max_to_keep = 0, overwrite = True ):
    folder_path = os.path.dirname( os.path.abspath( save_name )  )
    if not os.path.exists( folder_path  ):
        os.makedirs( folder_path )

    state_dicts = {}
    for key in module_dicts.keys():
        if isinstance( module_dicts[key], nn.DataParallel ):
            state_dicts[key] = module_dicts[key].module.state_dict()
        elif isinstance( module_dicts[key], nn.Module ):
            state_dicts[key] = module_dicts[key].state_dict()
        else:
            state_dicts[key] = module_dicts[key]

    if os.path.exists( save_name ):
        if overwrite:
            os.remove( save_name )
            torch.save( state_dicts, save_name )
        else:
            print("Warning: checkpoint file already exists!")
            return
    else:
        torch.save( state_dicts, save_name )

    if max_to_keep > 0:
        pt_file_list = glob(folder_path+"/*.pt")
        pt_file_list.sort( key= lambda x: os.path.getmtime(x) )
        for idx in range( len( pt_file_list ) - max_to_keep ):
            os.remove( pt_file_list[idx]  )


def get_lr(optimizer):
    return [param_group['lr'] for param_group in optimizer.param_groups ]

## Datautils

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from copy import deepcopy
from tqdm import tqdm

import ast, pickle
import pandas as pd

from nltk.tokenize import RegexpTokenizer
word_tok = RegexpTokenizer(r"\w+")


## the corpus has been preprocessed, so here only lower is needed
## all digits are kept, since sent2vec unigram embedding has digit embedding
## no stemming, no lemmatization
class SentenceTokenizer:
    def __init__(self ):
        pass
    def tokenize(self, sen ):
        return sen.lower()

class Vocab:
    def __init__(self, words, eos_token = "<eos>", pad_token = "<pad>", unk_token = "<unk>" ):
        self.words = words
        self.index_to_word = {}
        self.word_to_index = {}
        for idx in range( len(words) ):
            self.index_to_word[ idx ] = words[idx]
            self.word_to_index[ words[idx] ] = idx
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.eos_index = self.word_to_index[self.eos_token]
        self.pad_index = self.word_to_index[self.pad_token]

        self.tokenizer = SentenceTokenizer()

    def index2word( self, idx ):
        return self.index_to_word.get( idx, self.unk_token)
    def word2index( self, word ):
        return self.word_to_index.get( word, -1 )
    # The sentence needs to be tokenized
    def sent2seq( self, sent, max_len = None , tokenize = True):
        if tokenize:
            sent = self.tokenizer.tokenize(sent)
        seq = []
        for w in sent.split():
            if w in self.word_to_index:
                seq.append( self.word2index(w) )
        if max_len is not None:
            if len(seq) >= max_len:
                seq = seq[:max_len -1]
                seq.append( self.eos_index )
            else:
                seq.append( self.eos_index )
                seq += [ self.pad_index ] * ( max_len - len(seq) )
        return seq
    def seq2sent( self, seq ):
        sent = []
        for i in seq:
            if i == self.eos_index or i == self.pad_index:
                break
            sent.append( self.index2word(i) )
        return " ".join(sent)

class ExtractionDataset(Dataset):
    def __init__( self,  corpus, vocab , max_seq_len , max_doc_len  ):
        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len
        ## corpus is a list
        self.corpus = corpus

    def __len__(self):
        return len(self.corpus)

    def __getitem__( self, idx ):

        doc_data = self.corpus[idx]
        sentences = doc_data["text"]
        valid_sen_idxs = deepcopy( doc_data["indices"] )

        sentences = sentences[:self.max_doc_len]
        num_sentences_in_doc = len( sentences )

        np.random.shuffle( valid_sen_idxs )
        valid_sen_idxs = np.array( valid_sen_idxs )
        valid_sen_idxs = valid_sen_idxs[ valid_sen_idxs < num_sentences_in_doc ]
        valid_sen_idxs = valid_sen_idxs[:self.max_doc_len]

        stop_loss_mask = np.zeros( self.max_doc_len ).astype(np.float32)
        stop_loss_mask[:len(valid_sen_idxs) + 1] = 1.0

        valid_sen_idxs = np.array(valid_sen_idxs.tolist() + [-1] * ( self.max_doc_len - len(valid_sen_idxs)))

        sentences += [""] * ( self.max_doc_len - num_sentences_in_doc )
        doc_mask = np.array(  [ True if sen.strip() == "" else False for sen in  sentences   ]  )

        seqs = [  self.vocab.sent2seq( sen, self.max_seq_len ) for sen in sentences ]
        seqs = np.asarray( seqs )

        return seqs, valid_sen_idxs, doc_mask, stop_loss_mask

def tokenize_sent(sent):
    return " ".join( word_tok.tokenize( sent ) )


def load_dataset_for_inference( doc_info_path, dataframe_path ):
    doc_info = pickle.load(open(doc_info_path, "rb"))
    paper_dict = {}
    for count, pmcid in enumerate(doc_info):
        paper = doc_info[pmcid]
        paper_info = []
        for _, page in paper["pages"].items():
            for obj_id in page.get("ordered_id", []):
                paper_info.append( {
                    "global_id":page["objects"][obj_id]["global_id"],
                    "text":page["objects"][obj_id].get("text", "").replace("\n", " ")
                } )
        paper_dict[pmcid] = paper_info

    query_data = pd.read_csv(dataframe_path)
    if "global_id" in query_data:
        query_data["global_id"] = query_data["global_id"].apply(ast.literal_eval)

    corpus = []
    for pos in range(len( query_data )):
        question = query_data["question"][pos]
        question_type = query_data["question_type"][pos]
        pmcid = str(query_data["pmcid"][pos])
        if "global_id" in query_data:
            answer_global_ids = set(query_data["global_id"][pos])
        else:
            answer_global_ids = None

        corpus.append( {
            "question":question,
            "question_type":question_type,
            "pmcid":pmcid,
            "answer_global_ids":answer_global_ids
        } )

    return corpus, paper_dict

def load_corpus( doc_info_path, dataframe_path  ):

    doc_info = pickle.load(open(doc_info_path, "rb"))
    query_data = pd.read_csv(dataframe_path)
    query_data["global_id"] = query_data["global_id"].apply(ast.literal_eval)

    paper_dict = {}
    for count, pmcid in enumerate(doc_info):
        paper = doc_info[pmcid]
        paper_info = []
        for _, page in paper["pages"].items():
            for obj_id in page.get("ordered_id", []):
                paper_info.append( {
                    "global_id":page["objects"][obj_id]["global_id"],
                    "text":page["objects"][obj_id].get("text", "").replace("\n", " ")
                } )
        paper_dict[pmcid] = paper_info

    corpus = []
    for pos in range(len( query_data )):
        question = query_data["question"][pos]
        question_type = query_data["question_type"][pos]
        pmcid = str(query_data["pmcid"][pos])
        answer_global_ids = set(query_data["global_id"][pos])

        # valid_indices = []
        # sentence_list = [ question_type.replace("_", " ") + " : " + question ]
        # for text_block in paper_dict[pmcid]:
        #     if text_block["global_id"] in answer_global_ids:
        #         valid_indices.append( len( sentence_list ) )
        #     sentence_list.append( text_block["text"] )

        valid_indices = []
        sentence_list = [] #[ question + " : " + question_type.replace("_", " ") ]
        for text_block in paper_dict[pmcid]:
            if text_block["global_id"] in answer_global_ids:
                valid_indices.append( len( sentence_list ) )
            sentence_list.append(  question + " " + question_type.replace("_", " ") + " " + text_block["text"] )

        data_example = {
            "text":[ tokenize_sent(sen) for sen in sentence_list],
            "indices":valid_indices
        }

        corpus.append( data_example )
    return corpus

## Model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.distributions import Categorical

class AddMask( nn.Module ):
    def __init__( self, pad_index ):
        super().__init__()
        self.pad_index = pad_index
    def forward( self, x):
        # here x is a batch of input sequences (not embeddings) with the shape of [ batch_size, seq_len]
        mask = x == self.pad_index
        return mask

class PositionalEncoding( nn.Module ):
    def __init__(self,  embed_dim, max_seq_len = 512  ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        pe = torch.zeros( 1, max_seq_len,  embed_dim )
        for pos in range( max_seq_len ):
            for i in range( 0, embed_dim, 2 ):
                pe[ 0, pos, i ] = math.sin( pos / ( 10000 ** ( i/embed_dim ) )  )
                if i+1 < embed_dim:
                    pe[ 0, pos, i+1 ] = math.cos( pos / ( 10000** ( i/embed_dim ) ) )
        self.register_buffer( "pe", pe )
        ## register_buffer can register some variables that can be saved and loaded by state_dict, but not trainable since not accessible by model.parameters()
    def forward( self, x ):
        return x + self.pe[ :, : x.size(1), :]

class MultiHeadAttention( nn.Module ):
    def __init__(self, embed_dim, num_heads ):
        super().__init__()
        dim_per_head = int( embed_dim/num_heads )

        self.ln_q = nn.Linear( embed_dim, num_heads * dim_per_head )
        self.ln_k = nn.Linear( embed_dim, num_heads * dim_per_head )
        self.ln_v = nn.Linear( embed_dim, num_heads * dim_per_head )

        self.ln_out = nn.Linear( num_heads * dim_per_head, embed_dim )

        self.num_heads = num_heads
        self.dim_per_head = dim_per_head

    def forward( self, q,k,v, mask = None):
        q = self.ln_q( q )
        k = self.ln_k( k )
        v = self.ln_v( v )

        q = q.view( q.size(0), q.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )
        k = k.view( k.size(0), k.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )
        v = v.view( v.size(0), v.size(1),  self.num_heads, self.dim_per_head  ).transpose( 1,2 )

        a = self.scaled_dot_product_attention( q,k, mask )
        new_v = a.matmul(v)
        new_v = new_v.transpose( 1,2 ).contiguous()
        new_v = new_v.view( new_v.size(0), new_v.size(1), -1 )
        new_v = self.ln_out(new_v)
        return new_v

    def scaled_dot_product_attention( self, q, k, mask = None ):
        ## note the here q and k have converted into multi-head mode
        ## q's shape is [ Batchsize, num_heads, seq_len_q, dim_per_head ]
        ## k's shape is [ Batchsize, num_heads, seq_len_k, dim_per_head ]
        # scaled dot product
        a = q.matmul( k.transpose( 2,3 ) )/ math.sqrt( q.size(-1) )
        # apply mask (either padding mask or seqeunce mask)
        if mask is not None:
            a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 )
        # apply softmax, to get the likelihood as attention matrix
        a = F.softmax( a, dim=-1 )
        return a

class FeedForward( nn.Module ):
    def __init__( self, embed_dim, hidden_dim ):
        super().__init__()
        self.ln1 = nn.Linear( embed_dim, hidden_dim )
        self.ln2 = nn.Linear( hidden_dim, embed_dim )
    def forward(  self, x):
        net = F.relu(self.ln1(x))
        out = self.ln2(net)
        return out

class TransformerDecoderLayer( nn.Module ):
    def __init__(self, embed_dim, num_heads, hidden_dim ):
        super().__init__()
        self.masked_mha = MultiHeadAttention(  embed_dim, num_heads )
        self.norm1 = nn.LayerNorm( embed_dim )
        self.mha = MultiHeadAttention( embed_dim, num_heads )
        self.norm2 = nn.LayerNorm( embed_dim )
        self.feed_forward = FeedForward( embed_dim, hidden_dim )
        self.norm3 = nn.LayerNorm( embed_dim )
    def forward(self, encoder_output, x, src_mask, trg_mask , dropout_rate = 0. ):
        short_cut = x
        net = F.dropout(self.masked_mha( x,x,x, trg_mask ), p = dropout_rate)
        net = self.norm1( short_cut + net )
        short_cut = net
        net = F.dropout(self.mha( net, encoder_output, encoder_output, src_mask ), p = dropout_rate)
        net = self.norm2( short_cut + net )
        short_cut = net
        net = F.dropout(self.feed_forward( net ), p = dropout_rate)
        net = self.norm3( short_cut + net )
        return net

class MultiHeadPoolingLayer( nn.Module ):
    def __init__( self, embed_dim, num_heads  ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_per_head = int( embed_dim/num_heads )
        self.ln_attention_score = nn.Linear( embed_dim, num_heads )
        self.ln_value = nn.Linear( embed_dim,  num_heads * self.dim_per_head )
        self.ln_out = nn.Linear( num_heads * self.dim_per_head , embed_dim )
    def forward(self, input_embedding , mask=None):
        a = self.ln_attention_score( input_embedding )
        v = self.ln_value( input_embedding )

        a = a.view( a.size(0), a.size(1), self.num_heads, 1 ).transpose(1,2)
        v = v.view( v.size(0), v.size(1),  self.num_heads, self.dim_per_head  ).transpose(1,2)
        a = a.transpose(2,3)
        if mask is not None:
            a = a.masked_fill( mask.unsqueeze(1).unsqueeze(1) , -1e9 )
        a = F.softmax(a , dim = -1 )

        new_v = a.matmul(v)
        new_v = new_v.transpose( 1,2 ).contiguous()
        new_v = new_v.view( new_v.size(0), new_v.size(1) ,-1 ).squeeze(1)
        new_v = self.ln_out( new_v )
        return new_v

class LocalSentenceEncoder( nn.Module ):
    def __init__( self, vocab_size, pad_index, embed_dim, num_heads , hidden_dim, pretrained_word_embedding ):
        super().__init__()
        self.addmask = AddMask( pad_index )

        self.rnn = nn.LSTM(  embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
        self.mh_pool = MultiHeadPoolingLayer( 2*embed_dim, num_heads )
        self.norm_out = nn.LayerNorm( 2*embed_dim )
        self.ln_out = nn.Linear( 2*embed_dim, embed_dim )

        if pretrained_word_embedding is not None:
            ## make sure the pad embedding is 0
            pretrained_word_embedding[pad_index] = 0
            self.register_buffer( "word_embedding", torch.from_numpy( pretrained_word_embedding ) )
        else:
            self.register_buffer( "word_embedding", torch.randn( vocab_size, embed_dim ) )

    """
    input_seq 's shape:  batch_size x seq_len
    """
    def forward( self, input_seq, dropout_rate = 0. ):
        mask = self.addmask( input_seq )
        ## batch_size x seq_len x embed_dim
        net = self.word_embedding[ input_seq ]
        net, _ = self.rnn( net )
        net =  self.ln_out(F.relu(self.norm_out(self.mh_pool( net, mask ))))
        return net

class GlobalContextEncoder(nn.Module):
    def __init__(self, embed_dim,  num_heads, hidden_dim ):
        super().__init__()
        self.rnn = nn.LSTM(  embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
        self.norm_out = nn.LayerNorm( 2*embed_dim )
        self.ln_out = nn.Linear( 2*embed_dim, embed_dim )

    def forward(self, sen_embed, doc_mask, dropout_rate = 0.):
        net, _ = self.rnn( sen_embed )
        net = self.ln_out(F.relu( self.norm_out(net) ) )
        return net

class ExtractionContextDecoder( nn.Module ):
    def __init__( self, embed_dim,  num_heads, hidden_dim, num_dec_layers ):
        super().__init__()
        self.pos_encode = PositionalEncoding( embed_dim)
        self.layer_list = nn.ModuleList( [  TransformerDecoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
    ## remaining_mask: set all unextracted sen indices as True
    ## extraction_mask: set all extracted sen indices as True
    def forward( self, sen_embed, remaining_mask, extraction_mask, dropout_rate = 0. ):
        # sen_embed = self.pos_encode( sen_embed )

        net = sen_embed
        for layer in self.layer_list:
            #  encoder_output, x,  src_mask, trg_mask , dropout_rate = 0.
            net = layer( sen_embed, net, remaining_mask, extraction_mask, dropout_rate )
        return net

class Extractor( nn.Module ):
    def __init__( self, embed_dim, num_heads ):
        super().__init__()
        self.norm_input = nn.LayerNorm( 3*embed_dim  )

        self.ln_hidden1 = nn.Linear(  3*embed_dim, 2*embed_dim  )
        self.norm_hidden1 = nn.LayerNorm( 2*embed_dim  )

        self.ln_hidden2 = nn.Linear(  2*embed_dim, embed_dim  )
        self.norm_hidden2 = nn.LayerNorm( embed_dim  )

        self.ln_out = nn.Linear(  embed_dim, 1 )

        self.mh_pool = MultiHeadPoolingLayer( embed_dim, num_heads )
        self.norm_pool = nn.LayerNorm( embed_dim  )
        self.ln_stop = nn.Linear(  embed_dim, 1 )

    def forward( self, sen_embed, relevance_embed, redundancy_embed , extraction_mask, dropout_rate = 0. ):
        if redundancy_embed is None:
            redundancy_embed = torch.zeros_like( sen_embed )
        net = self.norm_input( F.dropout( torch.cat( [ sen_embed, relevance_embed, redundancy_embed ], dim = 2 ) , p = dropout_rate  )  )
        net = F.relu( self.norm_hidden1( F.dropout( self.ln_hidden1( net ) , p = dropout_rate  )   ))
        hidden_net = F.relu( self.norm_hidden2( F.dropout( self.ln_hidden2( net)  , p = dropout_rate  )  ))

        sen_logits = self.ln_out( hidden_net ).squeeze(2)

        net = F.relu( self.norm_pool(  F.dropout( self.mh_pool( hidden_net, extraction_mask) , p = dropout_rate  )  ))
        stop_logits = self.ln_stop( net ).squeeze(1)

        return sen_logits, stop_logits

class MemSumPDFVQA:
    def __init__( self, model_path, vocabulary_path, gpu = None , embed_dim=200, num_heads=8, hidden_dim = 1024, num_ehe_layers = 3,  max_seq_len =200, max_doc_len = 200  ):
        with open( vocabulary_path , "rb" ) as f:
            words = pickle.load(f)
        self.vocab = Vocab( words )
        vocab_size = len(words)
        self.local_sentence_encoder = LocalSentenceEncoder( vocab_size, self.vocab.pad_index, embed_dim,num_heads,hidden_dim,None )
        self.global_context_encoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim )
        self.extraction_context_decoder = ExtractionContextDecoder( embed_dim, num_heads, hidden_dim, num_ehe_layers )
        self.extractor = Extractor( embed_dim, num_heads )
        ckpt = torch.load( model_path, map_location = "cpu" )
        self.local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
        self.global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
        self.extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
        self.extractor.load_state_dict(ckpt["extractor"])

        self.device =  torch.device( "cuda:%d"%(gpu) if gpu is not None and torch.cuda.is_available() else "cpu"  )
        self.local_sentence_encoder.to(self.device)
        self.global_context_encoder.to(self.device)
        self.extraction_context_decoder.to(self.device)
        self.extractor.to(self.device)

        self.sentence_tokenizer = SentenceTokenizer()
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len

        self.word_tok = RegexpTokenizer(r"\w+")

    def tokenize_sent(self, sent):
        return " ".join( self.word_tok.tokenize( sent ) )


    def inference( self, question, question_type, doc_content, p_stop_thres = 0.2, max_extracted_sentences_per_document = 4 ):
        """doc_content is a list of { "global_id":global_id, "text":text }"""

        ## convert to sequence
        seqs = []
        doc_mask = []

        document = []
        for text_block in doc_content:
            document.append( self.tokenize_sent( question + " " + question_type.replace("_", " ") + " " + text_block["text"] ) )
        document = document[:self.max_doc_len]
        document = document + [""] * ( self.max_doc_len -  len(document) )

        doc_mask.append(  [ 1 if sen.strip() == "" else 0 for sen in  document   ] )
        document_sequences = []
        for sen in document:
            seq = self.vocab.sent2seq( sen, self.max_seq_len )
            document_sequences.append(seq)
        seqs.append(document_sequences)
        seqs = np.asarray(seqs)
        doc_mask = np.asarray(doc_mask) == 1
        seqs = torch.from_numpy(seqs).to(self.device)
        doc_mask = torch.from_numpy(doc_mask).to(self.device)

        num_documents = seqs.size(0)
        with torch.no_grad():
            num_sentences = seqs.size(1)
            sen_embed  = self.local_sentence_encoder( seqs.view(-1, seqs.size(2) )  )
            sen_embed = sen_embed.view( -1, num_sentences, sen_embed.size(1) )
            relevance_embed = self.global_context_encoder( sen_embed, doc_mask  )

            num_documents = seqs.size(0)
            doc_mask = doc_mask.detach().cpu().numpy()
            seqs = seqs.detach().cpu().numpy()

            extracted_sentences_positions = []

            for doc_i in range(num_documents):
                current_doc_mask = doc_mask[doc_i:doc_i+1]
                current_remaining_mask_np = np.ones_like(current_doc_mask ).astype(bool) | current_doc_mask
                current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(bool) | current_doc_mask

                current_sen_embed = sen_embed[doc_i:doc_i+1]
                current_relevance_embed = relevance_embed[ doc_i:doc_i+1 ]
                current_redundancy_embed = None

                current_hyps = []

                for step in range( max_extracted_sentences_per_document+1 ) :
                    current_extraction_mask = torch.from_numpy( current_extraction_mask_np ).to(self.device)
                    current_remaining_mask = torch.from_numpy( current_remaining_mask_np ).to(self.device)
                    if step > 0:
                        current_redundancy_embed = self.extraction_context_decoder( current_sen_embed, current_remaining_mask, current_extraction_mask  )
                    sen_logits, stop_logits = self.extractor( current_sen_embed, current_relevance_embed, current_redundancy_embed , current_extraction_mask  )
                    sen_logits = sen_logits.masked_fill( current_extraction_mask, -1e9 )
                    p_stop = stop_logits.sigmoid()

                    stop = p_stop.item() > p_stop_thres
                    sen_i = sen_logits.argmax(dim=1)[0].item()

                    if stop or step == max_extracted_sentences_per_document or sen_i >= len(doc_content):
                        extracted_sentences_positions.append( current_hyps )
                        break
                    else:
                        current_hyps.append(sen_i)
                        current_extraction_mask_np[0, sen_i] = True
                        current_remaining_mask_np[0, sen_i] = False

            pred_indices = extracted_sentences_positions[0]
            if len(pred_indices) == 0:
                return [ -1 ]
            else:
                return [ doc_content[idx]["global_id"] for idx in pred_indices ]

# Training

## Training Preparation

In [8]:
import os
import pickle
from tqdm import tqdm
import time
import os, sys
import copy
import re
from huggingface_hub import snapshot_download
from transformers import set_seed
## download the pretrained glove word embedding (200 dimension)
snapshot_download('nianlong/memsum-word-embedding', local_dir = "model/word_embedding" )

def update_moving_average( m_ema, m, decay ):
    with torch.no_grad():
        param_dict_m_ema =  m_ema.module.parameters()  if isinstance(  m_ema, nn.DataParallel ) else m_ema.parameters() 
        param_dict_m =  m.module.parameters()  if isinstance( m , nn.DataParallel ) else  m.parameters() 
        for param_m_ema, param_m in zip( param_dict_m_ema, param_dict_m ):
            param_m_ema.copy_( decay * param_m_ema + (1-decay) *  param_m )
            

def LOG( info, log_out_file, end="\n" ):
    with open( log_out_file, "a" ) as f:
        f.write( info + end )

def train_iteration(batch):
    seqs, valid_sen_idxs, doc_mask, stop_loss_mask = [_.to(device) for _ in batch]
    num_documents = seqs.size(0)
    num_sentences = seqs.size(1)

    local_sen_embed = local_sentence_encoder( seqs.view(-1, seqs.size(2) ), dropout_rate )
    local_sen_embed = local_sen_embed.view( -1, num_sentences, local_sen_embed.size(1) )
    global_context_embed = global_context_encoder( local_sen_embed, doc_mask , dropout_rate )

    doc_mask_np = doc_mask.detach().cpu().numpy()
    remaining_mask_np = np.ones_like( doc_mask_np ).astype( bool ) | doc_mask_np
    extraction_mask_np = np.zeros_like( doc_mask_np ).astype( bool ) | doc_mask_np

    sen_logits_list = []
    sen_indices_list = []
    stop_logits_list = []
    stop_loss_mask_list = []

    extraction_context_embed = None
    for step in range( min(valid_sen_idxs.shape[1], max_extracted_sentences_per_document ) ):
        remaining_mask = torch.from_numpy( remaining_mask_np ).to(device)
        extraction_mask = torch.from_numpy( extraction_mask_np ).to(device)
        if step > 0:
            extraction_context_embed = extraction_context_decoder( local_sen_embed, remaining_mask, extraction_mask, dropout_rate )
        sen_logits, stop_logits = extractor( local_sen_embed, global_context_embed, extraction_context_embed, extraction_mask , dropout_rate )
        sen_logits = sen_logits.masked_fill( extraction_mask, -1e9 )
        sen_indices = valid_sen_idxs[:, step]

        sen_logits_list.append( sen_logits.unsqueeze(1) )
        sen_indices_list.append( sen_indices.unsqueeze(1) )
        stop_logits_list.append( stop_logits.unsqueeze(1) )
        stop_loss_mask_list.append( stop_loss_mask[:, step:step+1] )

        for doc_i in range( num_documents ):
            sen_i = sen_indices[ doc_i ].item()
            if sen_i != -1:
                remaining_mask_np[doc_i,sen_i] = False
                extraction_mask_np[doc_i,sen_i] = True

    sen_logits_list = torch.cat(sen_logits_list, dim = 1 ).view(-1, num_sentences)
    sen_indices_list = torch.cat(sen_indices_list, dim = 1 ).view(-1)
    stop_logits_list = torch.cat(stop_logits_list, dim = 1 ).view(-1)
    stop_loss_mask_list = torch.cat( stop_loss_mask_list, dim = 1 ).view(-1)

    ce_loss = ce_loss_criterion( sen_logits_list, sen_indices_list )
    ce_loss = ce_loss.sum() / ( (sen_indices_list != -1).sum() + 1e-9  )

    stop_loss = bce_loss_criterion( stop_logits_list, (sen_indices_list == -1).to(torch.float32)  )
    stop_loss = (stop_loss*stop_loss_mask_list).sum() / ( stop_loss_mask_list.sum() + 1e-9  )

    loss = ce_loss + stop_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def calculate_exact_match_ratio(true_labels, predicted_labels):
    exact_match = 0
    for i in range(len(true_labels)):
        if true_labels[i] == predicted_labels[i]:
            exact_match += 1
    return exact_match / len(true_labels)

@torch.no_grad()
def validate(debug = False):
    predicted_indices = []
    label_indices = []

    predicted_sentences = []
    label_sentences = []

    for batch in val_data_loader:
        seqs, valid_sen_idxs, doc_mask, stop_loss_mask = [_.to(device) for _ in batch]

        num_sentences = seqs.size(1)
        local_sen_embed  = local_sentence_encoder_ema( seqs.view(-1, seqs.size(2) )  )
        local_sen_embed = local_sen_embed.view( -1, num_sentences, local_sen_embed.size(1) )
        global_context_embed = global_context_encoder_ema( local_sen_embed, doc_mask  )

        num_documents = seqs.size(0)
        doc_mask = doc_mask.detach().cpu().numpy()
        remaining_mask_np = np.ones_like( doc_mask ).astype( bool ) | doc_mask
        extraction_mask_np = np.zeros_like( doc_mask ).astype( bool ) | doc_mask

        done_list = []
        extraction_context_embed = None

        for step in range(max_extracted_sentences_per_document):
            remaining_mask = torch.from_numpy( remaining_mask_np ).to(device)
            extraction_mask = torch.from_numpy( extraction_mask_np ).to(device)
            if step > 0:
                extraction_context_embed = extraction_context_decoder_ema( local_sen_embed, remaining_mask, extraction_mask )
            sen_logits, stop_logits = extractor_ema( local_sen_embed, global_context_embed, extraction_context_embed , extraction_mask  )
            sen_logits = sen_logits.masked_fill( extraction_mask, -1e9 )
            p_stop = stop_logits.sigmoid()
            stop_action = p_stop > p_stop_thres

            done = stop_action | torch.all(extraction_mask, dim = 1)
            if len(done_list) > 0:
                done = torch.logical_or(done_list[-1], done)
            if torch.all( done ):
                break

            sen_indices = torch.argmax(sen_logits, dim =1)
            done_list.append(done)

            for doc_i in range( num_documents ):
                if not done[doc_i]:
                    sen_i = sen_indices[ doc_i ].item()
                    remaining_mask_np[doc_i,sen_i] = False
                    extraction_mask_np[doc_i,sen_i] = True

        predicted_indices += [ set(np.argwhere( remaining_mask_np[doc_i] == False )[:,0].tolist()) for doc_i in range( remaining_mask_np.shape[0] ) ]
        label_indices += [ set(valid_sen_idxs[doc_i][valid_sen_idxs[doc_i]!=-1].detach().cpu().numpy()) for doc_i in range( valid_sen_idxs.size(0) ) ]

        predicted_sentences += [ [ vocab.seq2sent( seqs[doc_i][sen_idx].cpu().numpy() ) for sen_idx in np.argwhere( remaining_mask_np[doc_i] == False )[:,0]] for doc_i in range( remaining_mask_np.shape[0] )   ]
        label_sentences += [ [ vocab.seq2sent( seqs[doc_i][sen_idx].cpu().numpy() ) for sen_idx in valid_sen_idxs[doc_i][valid_sen_idxs[doc_i]!=-1].detach().cpu().numpy()] for doc_i in range( valid_sen_idxs.size(0) )   ]

    score = calculate_exact_match_ratio( label_indices, predicted_indices )
    if not debug:
        return score
    else:
        return score, label_indices, predicted_indices, predicted_sentences, label_sentences

  from .autonotebook import tqdm as notebook_tqdm
Fetching 4 files: 100%|█████████████████████████████████████████████████| 4/4 [00:00<00:00,  9.24it/s]


## Start Training
The overall training process takes around **5 hours** using a single P100 GPU (available on Kaggle).

On Google Colab with a T4 GPU, it will take longer, roughly **9 hours** (Not recommended since Colab might automatically lose connection)

In [9]:
seed = 2023
data_folder = "pdfvqa/"
training_doc_info_path = data_folder + "/train_doc_info.pkl"
training_dataframe_path = data_folder + "/train_dataframe.csv"
validation_doc_info_path = data_folder + "/val_doc_info.pkl"
validation_dataframe_path = data_folder + "/val_dataframe.csv"
test_doc_info_path = data_folder + "/test_doc_info.pkl"
test_dataframe_path = data_folder + "/test_dataframe_without_answer.csv"
model_folder = "model/memsum_dqa/"
log_folder = "log/memsum_dqa/"
vocabulary_file_name = "model/word_embedding/vocabulary_200dim.pkl"
pretrained_unigram_embeddings_file_name = "model/word_embedding/unigram_embeddings_200dim.pkl"
num_heads = 8
hidden_dim = 1024
num_ehe_layers = 3
max_seq_len = 200
max_doc_len = 200
num_of_epochs = 50
print_every = 100
save_every = 200
validate_every = 200
restore_old_checkpoint = 1
learning_rate = 1e-4
warmup_step = 500
weight_decay = 0
dropout_rate = 0.2
n_device = 1
batch_size_per_device = 8
max_extracted_sentences_per_document = 4
moving_average_decay = 0.99
p_stop_thres = 0.2

set_seed(seed)

os.makedirs(log_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

train_log_out_file = log_folder + "/train.log"
val_log_out_file = log_folder + "/val.log"

training_corpus = load_corpus( training_doc_info_path, training_dataframe_path  )
validation_corpus = load_corpus( validation_doc_info_path, validation_dataframe_path )

with open( vocabulary_file_name, "rb") as f:
    words = pickle.load(f)
with open(pretrained_unigram_embeddings_file_name, "rb") as f:
    pretrained_embedding = pickle.load(f)
vocab = Vocab(words)
vocab_size, embed_dim = pretrained_embedding.shape

train_dataset = ExtractionDataset(  training_corpus,  vocab , max_seq_len,  max_doc_len)
train_data_loader = DataLoader( train_dataset, batch_size=batch_size_per_device * n_device , shuffle=True, num_workers=4, worker_init_fn = lambda x:[np.random.seed( seed+x ), torch.manual_seed( seed + x) ], drop_last= True,  pin_memory= True )
total_number_of_samples = train_dataset.__len__()
val_dataset = ExtractionDataset( validation_corpus, vocab, max_seq_len, max_doc_len )
val_data_loader = DataLoader( val_dataset, batch_size=batch_size_per_device * n_device , shuffle=False, num_workers=4, worker_init_fn = lambda x:[np.random.seed( seed+x ), torch.manual_seed( seed + x) ], drop_last= False,  pin_memory= True)

local_sentence_encoder = LocalSentenceEncoder( vocab_size,vocab.pad_index, embed_dim,num_heads,hidden_dim,pretrained_embedding )
global_context_encoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim )
extraction_context_decoder = ExtractionContextDecoder( embed_dim, num_heads, hidden_dim, num_ehe_layers )
extractor = Extractor( embed_dim, num_heads )

# restore most recent checkpoint
if restore_old_checkpoint:
    ckpt = load_model( model_folder )
else:
    ckpt = None

if ckpt is not None:
    local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
    global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
    extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
    extractor.load_state_dict( ckpt["extractor"] )
    LOG("model restored!", train_log_out_file)
    print("model restored!")

gpu_list = np.arange(n_device).tolist()
device = torch.device(  "cuda:%d"%( gpu_list[0] ) if torch.cuda.is_available() else "cpu" )

local_sentence_encoder_ema = copy.deepcopy( local_sentence_encoder ).to(device)
global_context_encoder_ema = copy.deepcopy( global_context_encoder  ).to(device)
extraction_context_decoder_ema = copy.deepcopy( extraction_context_decoder ).to(device)
extractor_ema = copy.deepcopy( extractor ).to(device)

local_sentence_encoder.to(device)
global_context_encoder.to(device)
extraction_context_decoder.to(device)
extractor.to(device)

if device.type == "cuda" and n_device > 1:
    local_sentence_encoder = nn.DataParallel( local_sentence_encoder, gpu_list )
    global_context_encoder = nn.DataParallel( global_context_encoder, gpu_list )
    extraction_context_decoder = nn.DataParallel( extraction_context_decoder, gpu_list )
    extractor = nn.DataParallel( extractor, gpu_list )

    local_sentence_encoder_ema = nn.DataParallel( local_sentence_encoder_ema, gpu_list )
    global_context_encoder_ema = nn.DataParallel( global_context_encoder_ema, gpu_list )
    extraction_context_decoder_ema = nn.DataParallel( extraction_context_decoder_ema, gpu_list )
    extractor_ema = nn.DataParallel( extractor_ema, gpu_list )

    model_parameters = [ par for par in local_sentence_encoder.module.parameters() if par.requires_grad  ]  + \
                    [ par for par in global_context_encoder.module.parameters() if par.requires_grad  ]   + \
                    [ par for par in extraction_context_decoder.module.parameters() if par.requires_grad  ]  + \
                    [ par for par in extractor.module.parameters() if par.requires_grad  ]
else:
    model_parameters =  [ par for par in local_sentence_encoder.parameters() if par.requires_grad  ]  + \
                    [ par for par in global_context_encoder.parameters() if par.requires_grad  ]   + \
                    [ par for par in extraction_context_decoder.parameters() if par.requires_grad  ]  + \
                    [ par for par in extractor.parameters() if par.requires_grad  ]

optimizer = torch.optim.Adam( model_parameters , lr= learning_rate , weight_decay = weight_decay )

if ckpt is not None:
    try:
        optimizer.load_state_dict( ckpt["optimizer"] )
        LOG("optimizer restored!", train_log_out_file)
        print("optimizer restored!")
    except:
        pass

current_epoch = 0
current_batch = 0

if ckpt is not None:
    current_batch = ckpt["current_batch"]
    current_epoch = int( current_batch * batch_size_per_device * n_device / total_number_of_samples )
    LOG("current_batch restored!", train_log_out_file)
    print("current_batch restored!")

ce_loss_criterion = nn.CrossEntropyLoss( reduction="none", ignore_index=-1 )
bce_loss_criterion = nn.BCEWithLogitsLoss( reduction="none" )

max_val_score = None
best_ckpt_step = None

for epoch in range( current_epoch, num_of_epochs ):
    running_loss = 0

    for count, batch in tqdm(enumerate(train_data_loader)):
        loss = train_iteration(batch)
        running_loss += loss

        update_moving_average(  local_sentence_encoder_ema,  local_sentence_encoder, moving_average_decay)
        update_moving_average(  global_context_encoder_ema ,  global_context_encoder, moving_average_decay)
        update_moving_average(  extraction_context_decoder_ema,  extraction_context_decoder, moving_average_decay)
        update_moving_average(  extractor_ema,  extractor, moving_average_decay)

        current_batch +=1
        if current_batch % print_every == 0:
            current_learning_rate = get_lr( optimizer )[0]
            LOG( "[current_batch: %05d] loss: %.3f, learning rate: %f"%( current_batch, running_loss/print_every,  current_learning_rate  ), train_log_out_file )
            print( "[current_batch: %05d] loss: %.3f, learning rate: %f"%( current_batch, running_loss/print_every, current_learning_rate  ) )
            os.system( "nvidia-smi > %s/gpu_usage.log"%( log_folder ) )
            running_loss = 0

        if ( validate_every != 0 and  current_batch % validate_every == 0 ) or count == len(train_data_loader) - 1:
            print("Starting validation ...")
            LOG("Starting validation ...", train_log_out_file)
            # validation
            val_score = validate()

            ## only save the checkpoint that performs best on validation set so far
            if max_val_score is None or val_score > max_val_score:
                max_val_score = val_score
                best_ckpt_step = current_batch
                save_model(  {
                    "current_batch":current_batch,
                    "local_sentence_encoder": local_sentence_encoder_ema,
                    "global_context_encoder": global_context_encoder_ema,
                    "extraction_context_decoder":extraction_context_decoder_ema,
                    "extractor":extractor_ema,
                    "optimizer": optimizer.state_dict()
                    } , model_folder+"/model_batch_%d.pt"%(current_batch), max_to_keep = 1 )

            if max_val_score is not None:
                print("[current_batch: %05d] val: %.4f, [best_batch: %05d] best val: %.4f"%(current_batch, val_score, best_ckpt_step, max_val_score))
                LOG("[current_batch: %05d] val: %.4f, [best_batch: %05d] best val: %.4f"%(current_batch, val_score, best_ckpt_step, max_val_score), val_log_out_file)
            else:
                print("[current_batch: %05d] val: %.4f"%(current_batch, val_score ))
                LOG("[current_batch: %05d] val: %.4f"%(current_batch, val_score ), val_log_out_file)


99it [00:26,  3.84it/s]

[current_batch: 00100] loss: 4.110, learning rate: 0.000100


199it [00:52,  3.79it/s]

[current_batch: 00200] loss: 3.368, learning rate: 0.000100
Starting validation ...


200it [00:59,  2.30s/it]

[current_batch: 00200] val: 0.1343, [best_batch: 00200] best val: 0.1343


299it [01:25,  3.85it/s]

[current_batch: 00300] loss: 3.119, learning rate: 0.000100


399it [01:51,  3.81it/s]

[current_batch: 00400] loss: 3.006, learning rate: 0.000100
Starting validation ...


400it [01:59,  2.32s/it]

[current_batch: 00400] val: 0.1583, [best_batch: 00400] best val: 0.1583


492it [02:23,  3.84it/s]

Starting validation ...


493it [02:30,  3.29it/s]

[current_batch: 00493] val: 0.2410, [best_batch: 00493] best val: 0.2410



6it [00:01,  3.72it/s]

[current_batch: 00500] loss: 0.202, learning rate: 0.000100


106it [00:28,  3.81it/s]

[current_batch: 00600] loss: 2.771, learning rate: 0.000100
Starting validation ...


107it [00:35,  2.31s/it]

[current_batch: 00600] val: 0.2719, [best_batch: 00600] best val: 0.2719


206it [01:01,  3.79it/s]

[current_batch: 00700] loss: 2.782, learning rate: 0.000100


306it [01:27,  3.80it/s]

[current_batch: 00800] loss: 2.770, learning rate: 0.000100
Starting validation ...


307it [01:35,  2.32s/it]

[current_batch: 00800] val: 0.2823, [best_batch: 00800] best val: 0.2823


406it [02:01,  3.78it/s]

[current_batch: 00900] loss: 2.623, learning rate: 0.000100


492it [02:24,  3.77it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 00986] val: 0.3029, [best_batch: 00986] best val: 0.3029



13it [00:03,  3.80it/s]

[current_batch: 01000] loss: 0.357, learning rate: 0.000100
Starting validation ...


14it [00:10,  2.20s/it]

[current_batch: 01000] val: 0.3029, [best_batch: 00986] best val: 0.3029


113it [00:36,  3.78it/s]

[current_batch: 01100] loss: 2.588, learning rate: 0.000100


213it [01:02,  3.80it/s]

[current_batch: 01200] loss: 2.608, learning rate: 0.000100
Starting validation ...


214it [01:10,  2.32s/it]

[current_batch: 01200] val: 0.3133, [best_batch: 01200] best val: 0.3133


313it [01:36,  3.81it/s]

[current_batch: 01300] loss: 2.588, learning rate: 0.000100


413it [02:02,  3.83it/s]

[current_batch: 01400] loss: 2.533, learning rate: 0.000100
Starting validation ...


414it [02:09,  2.32s/it]

[current_batch: 01400] val: 0.3167, [best_batch: 01400] best val: 0.3167


492it [02:30,  3.82it/s]

Starting validation ...


493it [02:36,  3.15it/s]

[current_batch: 01479] val: 0.3167, [best_batch: 01400] best val: 0.3167



20it [00:05,  3.77it/s]

[current_batch: 01500] loss: 0.546, learning rate: 0.000100


120it [00:32,  3.81it/s]

[current_batch: 01600] loss: 2.481, learning rate: 0.000100
Starting validation ...


121it [00:38,  2.21s/it]

[current_batch: 01600] val: 0.3150, [best_batch: 01400] best val: 0.3167


220it [01:04,  3.78it/s]

[current_batch: 01700] loss: 2.419, learning rate: 0.000100


320it [01:31,  3.79it/s]

[current_batch: 01800] loss: 2.533, learning rate: 0.000100
Starting validation ...


321it [01:38,  2.33s/it]

[current_batch: 01800] val: 0.3201, [best_batch: 01800] best val: 0.3201


420it [02:04,  3.81it/s]

[current_batch: 01900] loss: 2.312, learning rate: 0.000100


492it [02:23,  3.82it/s]

Starting validation ...


493it [02:31,  3.26it/s]

[current_batch: 01972] val: 0.3253, [best_batch: 01972] best val: 0.3253



27it [00:07,  3.79it/s]

[current_batch: 02000] loss: 0.695, learning rate: 0.000100
Starting validation ...


28it [00:14,  2.36s/it]

[current_batch: 02000] val: 0.3287, [best_batch: 02000] best val: 0.3287


127it [00:40,  3.78it/s]

[current_batch: 02100] loss: 2.467, learning rate: 0.000100


227it [01:07,  3.79it/s]

[current_batch: 02200] loss: 2.330, learning rate: 0.000100
Starting validation ...


228it [01:14,  2.36s/it]

[current_batch: 02200] val: 0.3339, [best_batch: 02200] best val: 0.3339


327it [01:40,  3.81it/s]

[current_batch: 02300] loss: 2.356, learning rate: 0.000100


427it [02:06,  3.84it/s]

[current_batch: 02400] loss: 2.362, learning rate: 0.000100
Starting validation ...


428it [02:13,  2.22s/it]

[current_batch: 02400] val: 0.3322, [best_batch: 02200] best val: 0.3339


492it [02:30,  3.80it/s]

Starting validation ...


493it [02:37,  3.14it/s]

[current_batch: 02465] val: 0.3236, [best_batch: 02200] best val: 0.3339



34it [00:09,  3.84it/s]

[current_batch: 02500] loss: 0.869, learning rate: 0.000100


134it [00:35,  3.81it/s]

[current_batch: 02600] loss: 2.374, learning rate: 0.000100
Starting validation ...


135it [00:43,  2.38s/it]

[current_batch: 02600] val: 0.3408, [best_batch: 02600] best val: 0.3408


234it [01:09,  3.81it/s]

[current_batch: 02700] loss: 2.327, learning rate: 0.000100


334it [01:35,  3.82it/s]

[current_batch: 02800] loss: 2.349, learning rate: 0.000100
Starting validation ...


335it [01:42,  2.24s/it]

[current_batch: 02800] val: 0.3408, [best_batch: 02600] best val: 0.3408


434it [02:08,  3.85it/s]

[current_batch: 02900] loss: 2.350, learning rate: 0.000100


492it [02:24,  3.87it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 02958] val: 0.3305, [best_batch: 02600] best val: 0.3408



41it [00:10,  3.87it/s]

[current_batch: 03000] loss: 0.972, learning rate: 0.000100
Starting validation ...


42it [00:17,  2.23s/it]

[current_batch: 03000] val: 0.3305, [best_batch: 02600] best val: 0.3408


141it [00:43,  3.85it/s]

[current_batch: 03100] loss: 2.271, learning rate: 0.000100


241it [01:10,  3.84it/s]

[current_batch: 03200] loss: 2.379, learning rate: 0.000100
Starting validation ...


242it [01:17,  2.39s/it]

[current_batch: 03200] val: 0.3511, [best_batch: 03200] best val: 0.3511


341it [01:43,  3.77it/s]

[current_batch: 03300] loss: 2.307, learning rate: 0.000100


441it [02:10,  3.78it/s]

[current_batch: 03400] loss: 2.187, learning rate: 0.000100
Starting validation ...


442it [02:17,  2.24s/it]

[current_batch: 03400] val: 0.3494, [best_batch: 03200] best val: 0.3511


492it [02:31,  3.76it/s]

Starting validation ...


493it [02:38,  3.12it/s]

[current_batch: 03451] val: 0.3546, [best_batch: 03451] best val: 0.3546



48it [00:12,  3.81it/s]

[current_batch: 03500] loss: 1.103, learning rate: 0.000100


148it [00:39,  3.80it/s]

[current_batch: 03600] loss: 2.197, learning rate: 0.000100
Starting validation ...


149it [00:46,  2.24s/it]

[current_batch: 03600] val: 0.3356, [best_batch: 03451] best val: 0.3546


248it [01:12,  3.80it/s]

[current_batch: 03700] loss: 2.145, learning rate: 0.000100


348it [01:39,  3.75it/s]

[current_batch: 03800] loss: 2.247, learning rate: 0.000100
Starting validation ...


349it [01:45,  2.25s/it]

[current_batch: 03800] val: 0.3442, [best_batch: 03451] best val: 0.3546


448it [02:12,  3.82it/s]

[current_batch: 03900] loss: 2.116, learning rate: 0.000100


492it [02:24,  3.84it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 03944] val: 0.3408, [best_batch: 03451] best val: 0.3546



55it [00:14,  3.83it/s]

[current_batch: 04000] loss: 1.199, learning rate: 0.000100
Starting validation ...


56it [00:21,  2.26s/it]

[current_batch: 04000] val: 0.3356, [best_batch: 03451] best val: 0.3546


155it [00:47,  3.81it/s]

[current_batch: 04100] loss: 2.118, learning rate: 0.000100


255it [01:14,  3.78it/s]

[current_batch: 04200] loss: 2.213, learning rate: 0.000100
Starting validation ...


256it [01:21,  2.40s/it]

[current_batch: 04200] val: 0.3580, [best_batch: 04200] best val: 0.3580


355it [01:47,  3.79it/s]

[current_batch: 04300] loss: 2.167, learning rate: 0.000100


455it [02:14,  3.82it/s]

[current_batch: 04400] loss: 2.198, learning rate: 0.000100
Starting validation ...


456it [02:21,  2.25s/it]

[current_batch: 04400] val: 0.3494, [best_batch: 04200] best val: 0.3580


492it [02:30,  3.84it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 04437] val: 0.3511, [best_batch: 04200] best val: 0.3580



62it [00:16,  3.81it/s]

[current_batch: 04500] loss: 1.335, learning rate: 0.000100


162it [00:42,  3.84it/s]

[current_batch: 04600] loss: 2.141, learning rate: 0.000100
Starting validation ...


163it [00:49,  2.26s/it]

[current_batch: 04600] val: 0.3563, [best_batch: 04200] best val: 0.3580


262it [01:15,  3.78it/s]

[current_batch: 04700] loss: 2.118, learning rate: 0.000100


362it [01:42,  3.80it/s]

[current_batch: 04800] loss: 2.058, learning rate: 0.000100
Starting validation ...


363it [01:49,  2.39s/it]

[current_batch: 04800] val: 0.3632, [best_batch: 04800] best val: 0.3632


462it [02:15,  3.82it/s]

[current_batch: 04900] loss: 2.083, learning rate: 0.000100


492it [02:23,  3.82it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 04930] val: 0.3632, [best_batch: 04800] best val: 0.3632



69it [00:18,  3.79it/s]

[current_batch: 05000] loss: 1.455, learning rate: 0.000100
Starting validation ...


70it [00:25,  2.25s/it]

[current_batch: 05000] val: 0.3597, [best_batch: 04800] best val: 0.3632


169it [00:51,  3.81it/s]

[current_batch: 05100] loss: 1.996, learning rate: 0.000100


269it [01:17,  3.84it/s]

[current_batch: 05200] loss: 2.004, learning rate: 0.000100
Starting validation ...


270it [01:25,  2.41s/it]

[current_batch: 05200] val: 0.3649, [best_batch: 05200] best val: 0.3649


369it [01:51,  3.82it/s]

[current_batch: 05300] loss: 2.098, learning rate: 0.000100


469it [02:18,  3.78it/s]

[current_batch: 05400] loss: 1.973, learning rate: 0.000100
Starting validation ...


470it [02:25,  2.39s/it]

[current_batch: 05400] val: 0.3718, [best_batch: 05400] best val: 0.3718


492it [02:31,  3.80it/s]

Starting validation ...


493it [02:37,  3.12it/s]

[current_batch: 05423] val: 0.3666, [best_batch: 05400] best val: 0.3718



76it [00:20,  3.85it/s]

[current_batch: 05500] loss: 1.577, learning rate: 0.000100


176it [00:46,  3.84it/s]

[current_batch: 05600] loss: 1.941, learning rate: 0.000100
Starting validation ...


177it [00:53,  2.25s/it]

[current_batch: 05600] val: 0.3683, [best_batch: 05400] best val: 0.3718


276it [01:19,  3.78it/s]

[current_batch: 05700] loss: 1.977, learning rate: 0.000100


376it [01:46,  3.84it/s]

[current_batch: 05800] loss: 1.923, learning rate: 0.000100
Starting validation ...


377it [01:53,  2.40s/it]

[current_batch: 05800] val: 0.3752, [best_batch: 05800] best val: 0.3752


476it [02:19,  3.72it/s]

[current_batch: 05900] loss: 2.048, learning rate: 0.000100


492it [02:24,  3.75it/s]

Starting validation ...


493it [02:31,  3.26it/s]

[current_batch: 05916] val: 0.3769, [best_batch: 05916] best val: 0.3769



83it [00:21,  3.77it/s]

[current_batch: 06000] loss: 1.558, learning rate: 0.000100
Starting validation ...


84it [00:28,  2.25s/it]

[current_batch: 06000] val: 0.3735, [best_batch: 05916] best val: 0.3769


183it [00:54,  3.84it/s]

[current_batch: 06100] loss: 1.902, learning rate: 0.000100


283it [01:21,  3.83it/s]

[current_batch: 06200] loss: 1.944, learning rate: 0.000100
Starting validation ...


284it [01:28,  2.25s/it]

[current_batch: 06200] val: 0.3649, [best_batch: 05916] best val: 0.3769


383it [01:54,  3.81it/s]

[current_batch: 06300] loss: 1.908, learning rate: 0.000100


483it [02:20,  3.82it/s]

[current_batch: 06400] loss: 1.977, learning rate: 0.000100
Starting validation ...


484it [02:27,  2.26s/it]

[current_batch: 06400] val: 0.3718, [best_batch: 05916] best val: 0.3769


492it [02:29,  2.67it/s]

Starting validation ...


493it [02:36,  3.15it/s]

[current_batch: 06409] val: 0.3666, [best_batch: 05916] best val: 0.3769



90it [00:23,  3.78it/s]

[current_batch: 06500] loss: 1.679, learning rate: 0.000100


190it [00:50,  3.83it/s]

[current_batch: 06600] loss: 1.750, learning rate: 0.000100
Starting validation ...


191it [00:57,  2.27s/it]

[current_batch: 06600] val: 0.3580, [best_batch: 05916] best val: 0.3769


290it [01:23,  3.79it/s]

[current_batch: 06700] loss: 1.899, learning rate: 0.000100


390it [01:50,  3.78it/s]

[current_batch: 06800] loss: 1.951, learning rate: 0.000100
Starting validation ...


391it [01:57,  2.26s/it]

[current_batch: 06800] val: 0.3614, [best_batch: 05916] best val: 0.3769


490it [02:23,  3.81it/s]

[current_batch: 06900] loss: 1.856, learning rate: 0.000100


492it [02:24,  3.02it/s]

Starting validation ...


493it [02:31,  3.26it/s]

[current_batch: 06902] val: 0.3701, [best_batch: 05916] best val: 0.3769



97it [00:25,  3.75it/s]

[current_batch: 07000] loss: 1.726, learning rate: 0.000100
Starting validation ...


98it [00:32,  2.26s/it]

[current_batch: 07000] val: 0.3666, [best_batch: 05916] best val: 0.3769


197it [00:58,  3.80it/s]

[current_batch: 07100] loss: 1.823, learning rate: 0.000100


297it [01:25,  3.82it/s]

[current_batch: 07200] loss: 1.806, learning rate: 0.000100
Starting validation ...


298it [01:32,  2.26s/it]

[current_batch: 07200] val: 0.3546, [best_batch: 05916] best val: 0.3769


397it [01:58,  3.79it/s]

[current_batch: 07300] loss: 1.762, learning rate: 0.000100


492it [02:24,  3.84it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 07395] val: 0.3769, [best_batch: 05916] best val: 0.3769



4it [00:01,  3.64it/s]

[current_batch: 07400] loss: 0.092, learning rate: 0.000100
Starting validation ...


5it [00:08,  2.68s/it]

[current_batch: 07400] val: 0.3718, [best_batch: 05916] best val: 0.3769


104it [00:34,  3.75it/s]

[current_batch: 07500] loss: 1.724, learning rate: 0.000100


204it [01:00,  3.81it/s]

[current_batch: 07600] loss: 1.705, learning rate: 0.000100
Starting validation ...


205it [01:07,  2.28s/it]

[current_batch: 07600] val: 0.3769, [best_batch: 05916] best val: 0.3769


304it [01:33,  3.83it/s]

[current_batch: 07700] loss: 1.819, learning rate: 0.000100


404it [02:00,  3.77it/s]

[current_batch: 07800] loss: 1.837, learning rate: 0.000100
Starting validation ...


405it [02:07,  2.27s/it]

[current_batch: 07800] val: 0.3580, [best_batch: 05916] best val: 0.3769


492it [02:30,  3.87it/s]

Starting validation ...


493it [02:36,  3.14it/s]

[current_batch: 07888] val: 0.3632, [best_batch: 05916] best val: 0.3769



11it [00:03,  3.79it/s]

[current_batch: 07900] loss: 0.241, learning rate: 0.000100


111it [00:29,  3.86it/s]

[current_batch: 08000] loss: 1.742, learning rate: 0.000100
Starting validation ...


112it [00:36,  2.26s/it]

[current_batch: 08000] val: 0.3769, [best_batch: 05916] best val: 0.3769


211it [01:02,  3.82it/s]

[current_batch: 08100] loss: 1.725, learning rate: 0.000100


311it [01:29,  3.80it/s]

[current_batch: 08200] loss: 1.714, learning rate: 0.000100
Starting validation ...


312it [01:36,  2.27s/it]

[current_batch: 08200] val: 0.3683, [best_batch: 05916] best val: 0.3769


411it [02:02,  3.74it/s]

[current_batch: 08300] loss: 1.739, learning rate: 0.000100


492it [02:23,  3.83it/s]

Starting validation ...


493it [02:31,  3.26it/s]

[current_batch: 08381] val: 0.3873, [best_batch: 08381] best val: 0.3873



18it [00:04,  3.81it/s]

[current_batch: 08400] loss: 0.323, learning rate: 0.000100
Starting validation ...


19it [00:12,  2.46s/it]

[current_batch: 08400] val: 0.3890, [best_batch: 08400] best val: 0.3890


118it [00:38,  3.83it/s]

[current_batch: 08500] loss: 1.667, learning rate: 0.000100


218it [01:04,  3.80it/s]

[current_batch: 08600] loss: 1.641, learning rate: 0.000100
Starting validation ...


219it [01:11,  2.27s/it]

[current_batch: 08600] val: 0.3821, [best_batch: 08400] best val: 0.3890


318it [01:38,  3.82it/s]

[current_batch: 08700] loss: 1.638, learning rate: 0.000100


418it [02:04,  3.76it/s]

[current_batch: 08800] loss: 1.710, learning rate: 0.000100
Starting validation ...


419it [02:11,  2.28s/it]

[current_batch: 08800] val: 0.3666, [best_batch: 08400] best val: 0.3890


492it [02:30,  3.81it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 08874] val: 0.3683, [best_batch: 08400] best val: 0.3890



25it [00:06,  3.85it/s]

[current_batch: 08900] loss: 0.352, learning rate: 0.000100


125it [00:33,  3.80it/s]

[current_batch: 09000] loss: 1.566, learning rate: 0.000100
Starting validation ...


126it [00:40,  2.27s/it]

[current_batch: 09000] val: 0.3855, [best_batch: 08400] best val: 0.3890


225it [01:06,  3.84it/s]

[current_batch: 09100] loss: 1.632, learning rate: 0.000100


325it [01:32,  3.83it/s]

[current_batch: 09200] loss: 1.624, learning rate: 0.000100
Starting validation ...


326it [01:39,  2.28s/it]

[current_batch: 09200] val: 0.3838, [best_batch: 08400] best val: 0.3890


425it [02:05,  3.76it/s]

[current_batch: 09300] loss: 1.630, learning rate: 0.000100


492it [02:23,  3.81it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 09367] val: 0.3735, [best_batch: 08400] best val: 0.3890



32it [00:08,  3.81it/s]

[current_batch: 09400] loss: 0.504, learning rate: 0.000100
Starting validation ...


33it [00:15,  2.28s/it]

[current_batch: 09400] val: 0.3718, [best_batch: 08400] best val: 0.3890


132it [00:41,  3.82it/s]

[current_batch: 09500] loss: 1.533, learning rate: 0.000100


232it [01:08,  3.79it/s]

[current_batch: 09600] loss: 1.490, learning rate: 0.000100
Starting validation ...


233it [01:15,  2.28s/it]

[current_batch: 09600] val: 0.3838, [best_batch: 08400] best val: 0.3890


332it [01:41,  3.84it/s]

[current_batch: 09700] loss: 1.641, learning rate: 0.000100


432it [02:07,  3.81it/s]

[current_batch: 09800] loss: 1.595, learning rate: 0.000100
Starting validation ...


433it [02:14,  2.27s/it]

[current_batch: 09800] val: 0.3890, [best_batch: 08400] best val: 0.3890


492it [02:29,  3.81it/s]

Starting validation ...


493it [02:36,  3.15it/s]

[current_batch: 09860] val: 0.3873, [best_batch: 08400] best val: 0.3890



39it [00:10,  3.75it/s]

[current_batch: 09900] loss: 0.618, learning rate: 0.000100


139it [00:36,  3.83it/s]

[current_batch: 10000] loss: 1.485, learning rate: 0.000100
Starting validation ...


140it [00:43,  2.28s/it]

[current_batch: 10000] val: 0.3787, [best_batch: 08400] best val: 0.3890


239it [01:10,  3.82it/s]

[current_batch: 10100] loss: 1.514, learning rate: 0.000100


339it [01:36,  3.81it/s]

[current_batch: 10200] loss: 1.491, learning rate: 0.000100
Starting validation ...


340it [01:44,  2.41s/it]

[current_batch: 10200] val: 0.3924, [best_batch: 10200] best val: 0.3924


439it [02:10,  3.79it/s]

[current_batch: 10300] loss: 1.488, learning rate: 0.000100


492it [02:24,  3.78it/s]

Starting validation ...


493it [02:31,  3.25it/s]

[current_batch: 10353] val: 0.3941, [best_batch: 10353] best val: 0.3941



46it [00:12,  3.78it/s]

[current_batch: 10400] loss: 0.663, learning rate: 0.000100
Starting validation ...


47it [00:19,  2.28s/it]

[current_batch: 10400] val: 0.3941, [best_batch: 10353] best val: 0.3941


146it [00:45,  3.76it/s]

[current_batch: 10500] loss: 1.512, learning rate: 0.000100


246it [01:12,  3.81it/s]

[current_batch: 10600] loss: 1.451, learning rate: 0.000100
Starting validation ...


247it [01:19,  2.43s/it]

[current_batch: 10600] val: 0.4114, [best_batch: 10600] best val: 0.4114


346it [01:45,  3.80it/s]

[current_batch: 10700] loss: 1.506, learning rate: 0.000100


446it [02:12,  3.82it/s]

[current_batch: 10800] loss: 1.446, learning rate: 0.000100
Starting validation ...


447it [02:19,  2.26s/it]

[current_batch: 10800] val: 0.3924, [best_batch: 10600] best val: 0.4114


492it [02:30,  3.85it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 10846] val: 0.3855, [best_batch: 10600] best val: 0.4114



53it [00:14,  3.84it/s]

[current_batch: 10900] loss: 0.736, learning rate: 0.000100


153it [00:40,  3.76it/s]

[current_batch: 11000] loss: 1.458, learning rate: 0.000100
Starting validation ...


154it [00:48,  2.43s/it]

[current_batch: 11000] val: 0.4131, [best_batch: 11000] best val: 0.4131


253it [01:14,  3.80it/s]

[current_batch: 11100] loss: 1.500, learning rate: 0.000100


353it [01:40,  3.80it/s]

[current_batch: 11200] loss: 1.400, learning rate: 0.000100
Starting validation ...


354it [01:47,  2.28s/it]

[current_batch: 11200] val: 0.3941, [best_batch: 11000] best val: 0.4131


453it [02:13,  3.88it/s]

[current_batch: 11300] loss: 1.408, learning rate: 0.000100


492it [02:24,  3.80it/s]

Starting validation ...


493it [02:30,  3.26it/s]

[current_batch: 11339] val: 0.4028, [best_batch: 11000] best val: 0.4131



60it [00:15,  3.79it/s]

[current_batch: 11400] loss: 0.808, learning rate: 0.000100
Starting validation ...


61it [00:22,  2.27s/it]

[current_batch: 11400] val: 0.3907, [best_batch: 11000] best val: 0.4131


160it [00:48,  3.83it/s]

[current_batch: 11500] loss: 1.378, learning rate: 0.000100


260it [01:15,  3.87it/s]

[current_batch: 11600] loss: 1.366, learning rate: 0.000100
Starting validation ...


261it [01:22,  2.27s/it]

[current_batch: 11600] val: 0.3821, [best_batch: 11000] best val: 0.4131


360it [01:48,  3.82it/s]

[current_batch: 11700] loss: 1.362, learning rate: 0.000100


460it [02:14,  3.81it/s]

[current_batch: 11800] loss: 1.407, learning rate: 0.000100
Starting validation ...


461it [02:21,  2.28s/it]

[current_batch: 11800] val: 0.3838, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.73it/s]

Starting validation ...


493it [02:36,  3.14it/s]

[current_batch: 11832] val: 0.3787, [best_batch: 11000] best val: 0.4131



67it [00:17,  3.82it/s]

[current_batch: 11900] loss: 0.881, learning rate: 0.000100


167it [00:44,  3.82it/s]

[current_batch: 12000] loss: 1.324, learning rate: 0.000100
Starting validation ...


168it [00:51,  2.27s/it]

[current_batch: 12000] val: 0.4028, [best_batch: 11000] best val: 0.4131


267it [01:17,  3.81it/s]

[current_batch: 12100] loss: 1.270, learning rate: 0.000100


367it [01:43,  3.81it/s]

[current_batch: 12200] loss: 1.373, learning rate: 0.000100
Starting validation ...


368it [01:50,  2.28s/it]

[current_batch: 12200] val: 0.3941, [best_batch: 11000] best val: 0.4131


467it [02:16,  3.85it/s]

[current_batch: 12300] loss: 1.311, learning rate: 0.000100


492it [02:23,  3.85it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 12325] val: 0.3924, [best_batch: 11000] best val: 0.4131



74it [00:19,  3.76it/s]

[current_batch: 12400] loss: 0.934, learning rate: 0.000100
Starting validation ...


75it [00:26,  2.29s/it]

[current_batch: 12400] val: 0.3907, [best_batch: 11000] best val: 0.4131


174it [00:52,  3.82it/s]

[current_batch: 12500] loss: 1.239, learning rate: 0.000100


274it [01:19,  3.78it/s]

[current_batch: 12600] loss: 1.337, learning rate: 0.000100
Starting validation ...


275it [01:26,  2.28s/it]

[current_batch: 12600] val: 0.3804, [best_batch: 11000] best val: 0.4131


374it [01:52,  3.79it/s]

[current_batch: 12700] loss: 1.253, learning rate: 0.000100


474it [02:18,  3.81it/s]

[current_batch: 12800] loss: 1.276, learning rate: 0.000100
Starting validation ...


475it [02:25,  2.28s/it]

[current_batch: 12800] val: 0.3804, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.72it/s]

Starting validation ...


493it [02:37,  3.14it/s]

[current_batch: 12818] val: 0.3769, [best_batch: 11000] best val: 0.4131



81it [00:21,  3.83it/s]

[current_batch: 12900] loss: 0.940, learning rate: 0.000100


181it [00:48,  3.83it/s]

[current_batch: 13000] loss: 1.234, learning rate: 0.000100
Starting validation ...


182it [00:55,  2.28s/it]

[current_batch: 13000] val: 0.3838, [best_batch: 11000] best val: 0.4131


281it [01:20,  3.84it/s]

[current_batch: 13100] loss: 1.207, learning rate: 0.000100


381it [01:47,  3.78it/s]

[current_batch: 13200] loss: 1.207, learning rate: 0.000100
Starting validation ...


382it [01:54,  2.27s/it]

[current_batch: 13200] val: 0.3907, [best_batch: 11000] best val: 0.4131


481it [02:20,  3.81it/s]

[current_batch: 13300] loss: 1.268, learning rate: 0.000100


492it [02:23,  3.82it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 13311] val: 0.3873, [best_batch: 11000] best val: 0.4131



88it [00:23,  3.75it/s]

[current_batch: 13400] loss: 0.990, learning rate: 0.000100
Starting validation ...


89it [00:30,  2.29s/it]

[current_batch: 13400] val: 0.3821, [best_batch: 11000] best val: 0.4131


188it [00:56,  3.77it/s]

[current_batch: 13500] loss: 1.132, learning rate: 0.000100


288it [01:22,  3.85it/s]

[current_batch: 13600] loss: 1.151, learning rate: 0.000100
Starting validation ...


289it [01:29,  2.27s/it]

[current_batch: 13600] val: 0.3993, [best_batch: 11000] best val: 0.4131


388it [01:55,  3.84it/s]

[current_batch: 13700] loss: 1.190, learning rate: 0.000100


488it [02:22,  3.80it/s]

[current_batch: 13800] loss: 1.237, learning rate: 0.000100
Starting validation ...


489it [02:29,  2.29s/it]

[current_batch: 13800] val: 0.3959, [best_batch: 11000] best val: 0.4131


492it [02:29,  1.05it/s]

Starting validation ...


493it [02:36,  3.15it/s]

[current_batch: 13804] val: 0.3959, [best_batch: 11000] best val: 0.4131



95it [00:25,  3.76it/s]

[current_batch: 13900] loss: 1.038, learning rate: 0.000100


195it [00:51,  3.82it/s]

[current_batch: 14000] loss: 1.086, learning rate: 0.000100
Starting validation ...


196it [00:58,  2.28s/it]

[current_batch: 14000] val: 0.3976, [best_batch: 11000] best val: 0.4131


295it [01:24,  3.77it/s]

[current_batch: 14100] loss: 1.150, learning rate: 0.000100


395it [01:51,  3.77it/s]

[current_batch: 14200] loss: 1.197, learning rate: 0.000100
Starting validation ...


396it [01:58,  2.29s/it]

[current_batch: 14200] val: 0.3821, [best_batch: 11000] best val: 0.4131


492it [02:24,  3.76it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 14297] val: 0.3873, [best_batch: 11000] best val: 0.4131



2it [00:00,  3.28it/s]

[current_batch: 14300] loss: 0.028, learning rate: 0.000100


102it [00:27,  3.76it/s]

[current_batch: 14400] loss: 1.041, learning rate: 0.000100
Starting validation ...


103it [00:34,  2.29s/it]

[current_batch: 14400] val: 0.3838, [best_batch: 11000] best val: 0.4131


202it [01:00,  3.84it/s]

[current_batch: 14500] loss: 1.136, learning rate: 0.000100


302it [01:26,  3.77it/s]

[current_batch: 14600] loss: 1.104, learning rate: 0.000100
Starting validation ...


303it [01:33,  2.29s/it]

[current_batch: 14600] val: 0.3907, [best_batch: 11000] best val: 0.4131


402it [01:59,  3.76it/s]

[current_batch: 14700] loss: 1.090, learning rate: 0.000100


492it [02:23,  3.78it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 14790] val: 0.3873, [best_batch: 11000] best val: 0.4131



9it [00:02,  3.80it/s]

[current_batch: 14800] loss: 0.086, learning rate: 0.000100
Starting validation ...


10it [00:09,  2.35s/it]

[current_batch: 14800] val: 0.3855, [best_batch: 11000] best val: 0.4131


109it [00:35,  3.81it/s]

[current_batch: 14900] loss: 0.960, learning rate: 0.000100


209it [01:01,  3.81it/s]

[current_batch: 15000] loss: 1.039, learning rate: 0.000100
Starting validation ...


210it [01:08,  2.28s/it]

[current_batch: 15000] val: 0.3907, [best_batch: 11000] best val: 0.4131


309it [01:35,  3.76it/s]

[current_batch: 15100] loss: 1.037, learning rate: 0.000100


409it [02:01,  3.75it/s]

[current_batch: 15200] loss: 1.130, learning rate: 0.000100
Starting validation ...


410it [02:09,  2.29s/it]

[current_batch: 15200] val: 0.3787, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.81it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 15283] val: 0.3752, [best_batch: 11000] best val: 0.4131



16it [00:04,  3.76it/s]

[current_batch: 15300] loss: 0.151, learning rate: 0.000100


116it [00:31,  3.74it/s]

[current_batch: 15400] loss: 0.937, learning rate: 0.000100
Starting validation ...


117it [00:38,  2.29s/it]

[current_batch: 15400] val: 0.3804, [best_batch: 11000] best val: 0.4131


216it [01:04,  3.79it/s]

[current_batch: 15500] loss: 1.009, learning rate: 0.000100


316it [01:30,  3.77it/s]

[current_batch: 15600] loss: 1.032, learning rate: 0.000100
Starting validation ...


317it [01:37,  2.29s/it]

[current_batch: 15600] val: 0.3683, [best_batch: 11000] best val: 0.4131


416it [02:03,  3.81it/s]

[current_batch: 15700] loss: 1.021, learning rate: 0.000100


492it [02:23,  3.82it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 15776] val: 0.3804, [best_batch: 11000] best val: 0.4131



23it [00:06,  3.80it/s]

[current_batch: 15800] loss: 0.223, learning rate: 0.000100
Starting validation ...


24it [00:13,  2.29s/it]

[current_batch: 15800] val: 0.3873, [best_batch: 11000] best val: 0.4131


123it [00:39,  3.76it/s]

[current_batch: 15900] loss: 1.001, learning rate: 0.000100


223it [01:05,  3.77it/s]

[current_batch: 16000] loss: 0.877, learning rate: 0.000100
Starting validation ...


224it [01:12,  2.29s/it]

[current_batch: 16000] val: 0.3855, [best_batch: 11000] best val: 0.4131


323it [01:38,  3.83it/s]

[current_batch: 16100] loss: 0.912, learning rate: 0.000100


423it [02:05,  3.80it/s]

[current_batch: 16200] loss: 1.064, learning rate: 0.000100
Starting validation ...


424it [02:12,  2.27s/it]

[current_batch: 16200] val: 0.3821, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.81it/s]

Starting validation ...


493it [02:36,  3.14it/s]

[current_batch: 16269] val: 0.3787, [best_batch: 11000] best val: 0.4131



30it [00:08,  3.80it/s]

[current_batch: 16300] loss: 0.297, learning rate: 0.000100


130it [00:34,  3.81it/s]

[current_batch: 16400] loss: 0.844, learning rate: 0.000100
Starting validation ...


131it [00:41,  2.29s/it]

[current_batch: 16400] val: 0.3873, [best_batch: 11000] best val: 0.4131


230it [01:07,  3.81it/s]

[current_batch: 16500] loss: 0.952, learning rate: 0.000100


330it [01:34,  3.81it/s]

[current_batch: 16600] loss: 0.957, learning rate: 0.000100
Starting validation ...


331it [01:41,  2.29s/it]

[current_batch: 16600] val: 0.3855, [best_batch: 11000] best val: 0.4131


430it [02:07,  3.79it/s]

[current_batch: 16700] loss: 0.957, learning rate: 0.000100


492it [02:24,  3.83it/s]

Starting validation ...


493it [02:31,  3.26it/s]

[current_batch: 16762] val: 0.3821, [best_batch: 11000] best val: 0.4131



37it [00:09,  3.83it/s]

[current_batch: 16800] loss: 0.302, learning rate: 0.000100
Starting validation ...


38it [00:16,  2.28s/it]

[current_batch: 16800] val: 0.3769, [best_batch: 11000] best val: 0.4131


137it [00:42,  3.81it/s]

[current_batch: 16900] loss: 0.851, learning rate: 0.000100


237it [01:09,  3.81it/s]

[current_batch: 17000] loss: 0.892, learning rate: 0.000100
Starting validation ...


238it [01:16,  2.29s/it]

[current_batch: 17000] val: 0.3804, [best_batch: 11000] best val: 0.4131


337it [01:42,  3.83it/s]

[current_batch: 17100] loss: 0.899, learning rate: 0.000100


437it [02:09,  3.74it/s]

[current_batch: 17200] loss: 0.940, learning rate: 0.000100
Starting validation ...


438it [02:16,  2.30s/it]

[current_batch: 17200] val: 0.3718, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.77it/s]

Starting validation ...


493it [02:37,  3.14it/s]

[current_batch: 17255] val: 0.3907, [best_batch: 11000] best val: 0.4131



44it [00:11,  3.81it/s]

[current_batch: 17300] loss: 0.422, learning rate: 0.000100


144it [00:38,  3.80it/s]

[current_batch: 17400] loss: 0.870, learning rate: 0.000100
Starting validation ...


145it [00:45,  2.29s/it]

[current_batch: 17400] val: 0.4062, [best_batch: 11000] best val: 0.4131


244it [01:11,  3.76it/s]

[current_batch: 17500] loss: 0.902, learning rate: 0.000100


344it [01:37,  3.82it/s]

[current_batch: 17600] loss: 0.840, learning rate: 0.000100
Starting validation ...


345it [01:44,  2.28s/it]

[current_batch: 17600] val: 0.3701, [best_batch: 11000] best val: 0.4131


444it [02:11,  3.77it/s]

[current_batch: 17700] loss: 0.870, learning rate: 0.000100


492it [02:23,  3.85it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 17748] val: 0.3666, [best_batch: 11000] best val: 0.4131



51it [00:13,  3.84it/s]

[current_batch: 17800] loss: 0.455, learning rate: 0.000100
Starting validation ...


52it [00:20,  2.30s/it]

[current_batch: 17800] val: 0.3666, [best_batch: 11000] best val: 0.4131


151it [00:46,  3.83it/s]

[current_batch: 17900] loss: 0.795, learning rate: 0.000100


251it [01:13,  3.84it/s]

[current_batch: 18000] loss: 0.797, learning rate: 0.000100
Starting validation ...


252it [01:20,  2.28s/it]

[current_batch: 18000] val: 0.3752, [best_batch: 11000] best val: 0.4131


351it [01:45,  3.81it/s]

[current_batch: 18100] loss: 0.869, learning rate: 0.000100


451it [02:12,  3.78it/s]

[current_batch: 18200] loss: 0.824, learning rate: 0.000100
Starting validation ...


452it [02:19,  2.29s/it]

[current_batch: 18200] val: 0.3838, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.82it/s]

Starting validation ...


493it [02:37,  3.14it/s]

[current_batch: 18241] val: 0.3821, [best_batch: 11000] best val: 0.4131



58it [00:15,  3.81it/s]

[current_batch: 18300] loss: 0.472, learning rate: 0.000100


158it [00:41,  3.81it/s]

[current_batch: 18400] loss: 0.855, learning rate: 0.000100
Starting validation ...


159it [00:48,  2.29s/it]

[current_batch: 18400] val: 0.3614, [best_batch: 11000] best val: 0.4131


258it [01:14,  3.82it/s]

[current_batch: 18500] loss: 0.831, learning rate: 0.000100


358it [01:41,  3.77it/s]

[current_batch: 18600] loss: 0.867, learning rate: 0.000100
Starting validation ...


359it [01:48,  2.28s/it]

[current_batch: 18600] val: 0.3855, [best_batch: 11000] best val: 0.4131


458it [02:14,  3.81it/s]

[current_batch: 18700] loss: 0.855, learning rate: 0.000100


492it [02:23,  3.79it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 18734] val: 0.3563, [best_batch: 11000] best val: 0.4131



65it [00:17,  3.74it/s]

[current_batch: 18800] loss: 0.542, learning rate: 0.000100
Starting validation ...


66it [00:24,  2.30s/it]

[current_batch: 18800] val: 0.3683, [best_batch: 11000] best val: 0.4131


165it [00:50,  3.78it/s]

[current_batch: 18900] loss: 0.746, learning rate: 0.000100


265it [01:17,  3.82it/s]

[current_batch: 19000] loss: 0.800, learning rate: 0.000100
Starting validation ...


266it [01:24,  2.29s/it]

[current_batch: 19000] val: 0.3614, [best_batch: 11000] best val: 0.4131


365it [01:50,  3.81it/s]

[current_batch: 19100] loss: 0.779, learning rate: 0.000100


465it [02:16,  3.78it/s]

[current_batch: 19200] loss: 0.836, learning rate: 0.000100
Starting validation ...


466it [02:23,  2.33s/it]

[current_batch: 19200] val: 0.3666, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.87it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 19227] val: 0.3718, [best_batch: 11000] best val: 0.4131



72it [00:18,  3.85it/s]

[current_batch: 19300] loss: 0.573, learning rate: 0.000100


172it [00:45,  3.83it/s]

[current_batch: 19400] loss: 0.762, learning rate: 0.000100
Starting validation ...


173it [00:52,  2.29s/it]

[current_batch: 19400] val: 0.3597, [best_batch: 11000] best val: 0.4131


272it [01:18,  3.77it/s]

[current_batch: 19500] loss: 0.800, learning rate: 0.000100


372it [01:44,  3.82it/s]

[current_batch: 19600] loss: 0.825, learning rate: 0.000100
Starting validation ...


373it [01:51,  2.30s/it]

[current_batch: 19600] val: 0.3546, [best_batch: 11000] best val: 0.4131


472it [02:18,  3.78it/s]

[current_batch: 19700] loss: 0.791, learning rate: 0.000100


492it [02:23,  3.81it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 19720] val: 0.3614, [best_batch: 11000] best val: 0.4131



79it [00:20,  3.75it/s]

[current_batch: 19800] loss: 0.553, learning rate: 0.000100
Starting validation ...


80it [00:27,  2.29s/it]

[current_batch: 19800] val: 0.3718, [best_batch: 11000] best val: 0.4131


179it [00:54,  3.81it/s]

[current_batch: 19900] loss: 0.715, learning rate: 0.000100


279it [01:20,  3.83it/s]

[current_batch: 20000] loss: 0.782, learning rate: 0.000100
Starting validation ...


280it [01:27,  2.29s/it]

[current_batch: 20000] val: 0.3701, [best_batch: 11000] best val: 0.4131


379it [01:53,  3.80it/s]

[current_batch: 20100] loss: 0.744, learning rate: 0.000100


479it [02:20,  3.80it/s]

[current_batch: 20200] loss: 0.798, learning rate: 0.000100
Starting validation ...


480it [02:27,  2.29s/it]

[current_batch: 20200] val: 0.3666, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.44it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 20213] val: 0.3649, [best_batch: 11000] best val: 0.4131



86it [00:22,  3.87it/s]

[current_batch: 20300] loss: 0.639, learning rate: 0.000100


186it [00:49,  3.86it/s]

[current_batch: 20400] loss: 0.685, learning rate: 0.000100
Starting validation ...


187it [00:56,  2.29s/it]

[current_batch: 20400] val: 0.3821, [best_batch: 11000] best val: 0.4131


286it [01:22,  3.85it/s]

[current_batch: 20500] loss: 0.756, learning rate: 0.000100


386it [01:48,  3.81it/s]

[current_batch: 20600] loss: 0.742, learning rate: 0.000100
Starting validation ...


387it [01:55,  2.29s/it]

[current_batch: 20600] val: 0.3787, [best_batch: 11000] best val: 0.4131


486it [02:21,  3.78it/s]

[current_batch: 20700] loss: 0.764, learning rate: 0.000100


492it [02:23,  3.58it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 20706] val: 0.3769, [best_batch: 11000] best val: 0.4131



93it [00:24,  3.82it/s]

[current_batch: 20800] loss: 0.698, learning rate: 0.000100
Starting validation ...


94it [00:31,  2.30s/it]

[current_batch: 20800] val: 0.3718, [best_batch: 11000] best val: 0.4131


193it [00:57,  3.80it/s]

[current_batch: 20900] loss: 0.742, learning rate: 0.000100


293it [01:24,  3.81it/s]

[current_batch: 21000] loss: 0.682, learning rate: 0.000100
Starting validation ...


294it [01:31,  2.29s/it]

[current_batch: 21000] val: 0.3580, [best_batch: 11000] best val: 0.4131


393it [01:57,  3.79it/s]

[current_batch: 21100] loss: 0.715, learning rate: 0.000100


492it [02:23,  3.81it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 21199] val: 0.3787, [best_batch: 11000] best val: 0.4131



0it [00:00, ?it/s]

[current_batch: 21200] loss: 0.007, learning rate: 0.000100
Starting validation ...


1it [00:07,  7.17s/it]

[current_batch: 21200] val: 0.3787, [best_batch: 11000] best val: 0.4131


100it [00:33,  3.78it/s]

[current_batch: 21300] loss: 0.658, learning rate: 0.000100


200it [00:59,  3.79it/s]

[current_batch: 21400] loss: 0.715, learning rate: 0.000100
Starting validation ...


201it [01:06,  2.30s/it]

[current_batch: 21400] val: 0.3683, [best_batch: 11000] best val: 0.4131


300it [01:32,  3.86it/s]

[current_batch: 21500] loss: 0.720, learning rate: 0.000100


400it [01:59,  3.81it/s]

[current_batch: 21600] loss: 0.726, learning rate: 0.000100
Starting validation ...


401it [02:06,  2.29s/it]

[current_batch: 21600] val: 0.3563, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.76it/s]

Starting validation ...


493it [02:36,  3.14it/s]

[current_batch: 21692] val: 0.3632, [best_batch: 11000] best val: 0.4131



7it [00:01,  3.75it/s]

[current_batch: 21700] loss: 0.070, learning rate: 0.000100


107it [00:28,  3.84it/s]

[current_batch: 21800] loss: 0.685, learning rate: 0.000100
Starting validation ...


108it [00:35,  2.28s/it]

[current_batch: 21800] val: 0.3769, [best_batch: 11000] best val: 0.4131


207it [01:01,  3.85it/s]

[current_batch: 21900] loss: 0.751, learning rate: 0.000100


307it [01:27,  3.82it/s]

[current_batch: 22000] loss: 0.697, learning rate: 0.000100
Starting validation ...


308it [01:34,  2.30s/it]

[current_batch: 22000] val: 0.3649, [best_batch: 11000] best val: 0.4131


407it [02:00,  3.86it/s]

[current_batch: 22100] loss: 0.716, learning rate: 0.000100


492it [02:23,  3.84it/s]

Starting validation ...


493it [02:30,  3.28it/s]

[current_batch: 22185] val: 0.3804, [best_batch: 11000] best val: 0.4131



14it [00:03,  3.79it/s]

[current_batch: 22200] loss: 0.110, learning rate: 0.000100
Starting validation ...


15it [00:10,  2.30s/it]

[current_batch: 22200] val: 0.3735, [best_batch: 11000] best val: 0.4131


114it [00:36,  3.83it/s]

[current_batch: 22300] loss: 0.631, learning rate: 0.000100


214it [01:03,  3.83it/s]

[current_batch: 22400] loss: 0.637, learning rate: 0.000100
Starting validation ...


215it [01:10,  2.30s/it]

[current_batch: 22400] val: 0.3718, [best_batch: 11000] best val: 0.4131


314it [01:36,  3.81it/s]

[current_batch: 22500] loss: 0.709, learning rate: 0.000100


414it [02:02,  3.84it/s]

[current_batch: 22600] loss: 0.709, learning rate: 0.000100
Starting validation ...


415it [02:09,  2.29s/it]

[current_batch: 22600] val: 0.3787, [best_batch: 11000] best val: 0.4131


492it [02:29,  3.82it/s]

Starting validation ...


493it [02:36,  3.15it/s]

[current_batch: 22678] val: 0.3787, [best_batch: 11000] best val: 0.4131



21it [00:05,  3.89it/s]

[current_batch: 22700] loss: 0.144, learning rate: 0.000100


121it [00:32,  3.81it/s]

[current_batch: 22800] loss: 0.664, learning rate: 0.000100
Starting validation ...


122it [00:39,  2.29s/it]

[current_batch: 22800] val: 0.3649, [best_batch: 11000] best val: 0.4131


221it [01:05,  3.74it/s]

[current_batch: 22900] loss: 0.664, learning rate: 0.000100


321it [01:31,  3.78it/s]

[current_batch: 23000] loss: 0.671, learning rate: 0.000100
Starting validation ...


322it [01:38,  2.30s/it]

[current_batch: 23000] val: 0.3494, [best_batch: 11000] best val: 0.4131


421it [02:04,  3.81it/s]

[current_batch: 23100] loss: 0.774, learning rate: 0.000100


492it [02:23,  3.77it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 23171] val: 0.3769, [best_batch: 11000] best val: 0.4131



28it [00:07,  3.74it/s]

[current_batch: 23200] loss: 0.189, learning rate: 0.000100
Starting validation ...


29it [00:14,  2.30s/it]

[current_batch: 23200] val: 0.3752, [best_batch: 11000] best val: 0.4131


128it [00:40,  3.79it/s]

[current_batch: 23300] loss: 0.666, learning rate: 0.000100


228it [01:07,  3.82it/s]

[current_batch: 23400] loss: 0.697, learning rate: 0.000100
Starting validation ...


229it [01:14,  2.28s/it]

[current_batch: 23400] val: 0.3683, [best_batch: 11000] best val: 0.4131


328it [01:40,  3.80it/s]

[current_batch: 23500] loss: 0.708, learning rate: 0.000100


428it [02:07,  3.79it/s]

[current_batch: 23600] loss: 0.707, learning rate: 0.000100
Starting validation ...


429it [02:14,  2.29s/it]

[current_batch: 23600] val: 0.3718, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.76it/s]

Starting validation ...


493it [02:37,  3.13it/s]

[current_batch: 23664] val: 0.3735, [best_batch: 11000] best val: 0.4131



35it [00:09,  3.75it/s]

[current_batch: 23700] loss: 0.215, learning rate: 0.000100


135it [00:36,  3.76it/s]

[current_batch: 23800] loss: 0.669, learning rate: 0.000100
Starting validation ...


136it [00:43,  2.30s/it]

[current_batch: 23800] val: 0.3563, [best_batch: 11000] best val: 0.4131


235it [01:09,  3.80it/s]

[current_batch: 23900] loss: 0.646, learning rate: 0.000100


335it [01:35,  3.83it/s]

[current_batch: 24000] loss: 0.729, learning rate: 0.000100
Starting validation ...


336it [01:42,  2.28s/it]

[current_batch: 24000] val: 0.3701, [best_batch: 11000] best val: 0.4131


435it [02:08,  3.81it/s]

[current_batch: 24100] loss: 0.666, learning rate: 0.000100


492it [02:24,  3.79it/s]

Starting validation ...


493it [02:30,  3.27it/s]

[current_batch: 24157] val: 0.3649, [best_batch: 11000] best val: 0.4131



42it [00:11,  3.82it/s]

[current_batch: 24200] loss: 0.311, learning rate: 0.000100
Starting validation ...


43it [00:18,  2.29s/it]

[current_batch: 24200] val: 0.3614, [best_batch: 11000] best val: 0.4131


142it [00:44,  3.81it/s]

[current_batch: 24300] loss: 0.594, learning rate: 0.000100


242it [01:10,  3.81it/s]

[current_batch: 24400] loss: 0.687, learning rate: 0.000100
Starting validation ...


243it [01:17,  2.29s/it]

[current_batch: 24400] val: 0.3442, [best_batch: 11000] best val: 0.4131


342it [01:43,  3.77it/s]

[current_batch: 24500] loss: 0.640, learning rate: 0.000100


442it [02:10,  3.77it/s]

[current_batch: 24600] loss: 0.684, learning rate: 0.000100
Starting validation ...


443it [02:17,  2.29s/it]

[current_batch: 24600] val: 0.3546, [best_batch: 11000] best val: 0.4131


492it [02:30,  3.84it/s]

Starting validation ...


493it [02:37,  3.14it/s]

[current_batch: 24650] val: 0.3528, [best_batch: 11000] best val: 0.4131





# Load the checkpoint that performs best on the validation set

In [10]:
model_ckpt_path = sorted(glob("model/memsum_dqa/*.pt"), key = os.path.getmtime  )[-1]
print("Final model checkpoint path:", model_ckpt_path)

Final model checkpoint path: model/memsum_dqa/model_batch_11000.pt


In [11]:
memsum_pdf_vqa = MemSumPDFVQA( model_ckpt_path,
                               "model/word_embedding/vocabulary_200dim.pkl",
                               0
                             )

# Validate

In [12]:
val_corpus, paper_dict = load_dataset_for_inference( validation_doc_info_path, validation_dataframe_path )
for example in val_corpus:
    prediction = memsum_pdf_vqa.inference(
                                            example["question"],
                                            example["question_type"],
                                            paper_dict[ example["pmcid"] ],
                                            p_stop_thres = 0.2,
                                            max_extracted_sentences_per_document = 4
                                         )
    example["prediction_global_ids"] = set(prediction)

In [13]:
#### question type "parent_relationship_understanding"
label_ids, pred_ids = [], []
for example in val_corpus:
    if example["question_type"] == "parent_relationship_understanding":
        label_ids.append( example["answer_global_ids"] )
        pred_ids.append( example["prediction_global_ids"] )
print( "parent_relationship_understanding:", calculate_exact_match_ratio( label_ids, pred_ids  ) )

#### question type "child_relationship_understanding"
label_ids, pred_ids = [], []
for example in val_corpus:
    if example["question_type"] == "child_relationship_understanding":
        label_ids.append( example["answer_global_ids"] )
        pred_ids.append( example["prediction_global_ids"] )
print( "child_relationship_understanding:", calculate_exact_match_ratio( label_ids, pred_ids  ) )

#### question type overall
label_ids, pred_ids = [], []
for example in val_corpus:
    label_ids.append( example["answer_global_ids"] )
    pred_ids.append( example["prediction_global_ids"] )
print( "overall:", calculate_exact_match_ratio( label_ids, pred_ids  ) )

parent_relationship_understanding: 0.33905579399141633
child_relationship_understanding: 0.7130434782608696
overall: 0.41308089500860584


# Test

In [14]:
corpus, paper_dict = load_dataset_for_inference( test_doc_info_path, test_dataframe_path )
predicted_labels = []
for idx in range(len(corpus)):
    question = corpus[idx]["question"]
    question_type = corpus[idx]["question_type"]
    doc_content = paper_dict[ corpus[idx]["pmcid"] ]
    predicted_labels.append(  memsum_pdf_vqa.inference( question, question_type, doc_content,
                                                            p_stop_thres = 0.2, max_extracted_sentences_per_document = 4 ) )

id_list = range(0, len(predicted_labels))
df = pd.DataFrame(id_list, columns=['id'])
df['answer'] = predicted_labels

df.to_csv('submission.csv', index=False)