In [1]:
import os
#Set a specific GPU

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "0"  # Set the GPU 0 to use

In [2]:
import torch
import transformers

from transformers import BertTokenizer
from transformers import BertForQuestionAnswering
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np
import random
import time
import datetime
import re
import pickle
from tqdm.notebook import tqdm

In [3]:
import wandb

wandb.init(project='test_jlk_nlp_qa_n')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjun171[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Device: cuda
Current cuda device: 0
Count of using GPUs: 1


In [5]:
main_loc = '../'

with open(main_loc+'data/encoded_data_nl_s_fin.pkl','rb') as f:
    inputs_o, starts_o, ends_o, len_qs_o, masks_o = pickle.load(f)


In [6]:
# inputs_o, _, starts_o, _, ends_o, _, masks_o, _ = \
#     train_test_split(inputs_o, starts_o, ends_o, masks_o, random_state=1, test_size=0.98)

In [7]:
# with open(main_loc+'data/encoded_data_nl_small.pkl','wb') as f:
#     pickle.dump([inputs_o, starts_o, ends_o, masks_o],f)

In [8]:
train_inputs, test_inputs, train_starts, test_starts, train_ends, test_ends, train_masks, test_masks = \
    train_test_split(inputs_o, starts_o, ends_o, masks_o, random_state=1, test_size=0.3)
    
train_inputs, validation_inputs, train_starts, validation_starts, train_ends, validation_ends, train_masks,validation_masks = \
    train_test_split(train_inputs, train_starts, train_ends, train_masks, random_state=0, test_size=0.1)

In [9]:
root_address = main_loc+'model/'

In [10]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

In [11]:
# i = 0
# print(tokenizer.decode(inputs_o[i][starts_o[i]:ends_o[i]+1]))
# print(tokenizer.decode(inputs_o[i][len_qs_o[i][0]:len_qs_o[i][1]]))

In [12]:
train_inputs = torch.tensor(train_inputs)
train_starts = torch.tensor(train_starts)
train_ends = torch.tensor(train_ends)
train_masks = torch.tensor(train_masks)

validation_inputs = torch.tensor(validation_inputs)
validation_starts = torch.tensor(validation_starts)
validation_ends = torch.tensor(validation_ends)
validation_masks = torch.tensor(validation_masks)

test_inputs = torch.tensor(test_inputs)
test_starts = torch.tensor(test_starts)
test_ends = torch.tensor(test_ends)
test_masks = torch.tensor(test_masks)

In [13]:
BATCH_SIZE = 24 #24

train_data = TensorDataset(train_inputs, train_masks, train_starts, train_ends)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_starts, validation_ends)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=BATCH_SIZE)

test_data = TensorDataset(test_inputs, test_masks, test_starts, test_ends)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE)

In [14]:
# model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

# model.to(device.type)

In [15]:
model = torch.load(root_address+'model_max_acc.pt')

model.to(device.type)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12,

In [16]:
optimizer = AdamW(model.parameters(),
                  lr =1e-5, #3e-5
                  eps = 1e-12
                )

epochs = 100


total_steps = len(train_dataloader) * epochs


scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps,
                                            num_cycles=20
                                            )

In [17]:
def cal_dis(f,id):
    rst = 0
    
    if abs(f-id)<5:
        rst+=1/((abs(f-id)+1)**2)
    else:
        pass
    
    return rst

def calc_ans(flats,ids):
    rst = 0
    
    for f,id in zip(flats,ids):
        rst+=cal_dis(f,id)

    return rst

# def calc_ans_bt(st,end,st_ids,end_ids):
#     rst = 0
    
#     for sf,ef,s_id,e_id in zip(st, end, st_ids, end_ids):
#         rst+= cal_dis(sf,s_id)*cal_dis(ef,e_id)
    
#     return rst

def is_same_text(sf,ef,s_id,e_id,input_id):
    re_t = tokenizer.decode(input_id[s_id:e_id+1]).strip()
    ex_t = tokenizer.decode(input_id[sf:ef+1]).strip()
    if re_t == ex_t:    
        return True
    else:
        return False

def calc_ans_bt(st,end,st_ids,end_ids,input_ids):
    rst = 0
    
    for sf,ef,s_id,e_id,input_id in zip(st, end, st_ids, end_ids,input_ids):
        if is_same_text(sf,ef,s_id,e_id,input_id):
            rst+=1

    return rst

def print_expect_val(s_e,e_e,s_i,e_i,input_id):
    re_t = tokenizer.decode(input_id[s_i:e_i+1])
    print(f'Real text: {re_t}')
    if s_e<=e_e:
        ex_t = tokenizer.decode(input_id[s_e:e_e+1])
        print(f'Expect text: {ex_t}')
    else:
        print('No Expect text')
    print()

# def flatten_proper_ans_set(st_s,end_s,n):
#     st_set = []
#     end_set = []
#     for st_p, end_p in zip(st_s,end_s):
#         st = st_p[-n:][0]
#         end = end_p[-n:][0]
#         st_set.append(st)
#         end_set.append(end)
    
#     return st_set,end_set


def flat_accuracy(st_ps,end_ps,st_ids,end_ids,input_ids, pr_t = False,k_n = 3):
    # st_flats = np.argsort(st_ps,kth = -k_n, axis=1)
    # end_flats = np.argsort(end_ps,kth = -k_n, axis=1)
    
    st_flat = np.argmax(st_ps, axis=1).flatten()
    end_flat = np.argmax(end_ps, axis=1).flatten()
    
    # st_acc = np.sum(st_flat == st_ids) / len(st_ids)
    # end_acc = np.sum(end_flat == end_ids) / len(end_ids)    
    # acc = np.sum((st_flat == st_ids) * (end_flat==end_ids)) / len(st_ids)   
    
    if pr_t: #check expected values
        for s_e,e_e,s_i,e_i,input_id in zip(st_flat,end_flat,st_ids,end_ids,input_ids):      
            if is_same_text(s_e,e_e,s_i,e_i,input_id)==False:
                print(f'Expect: {s_e},{e_e}. Real: {s_i},{e_i}')
                print_expect_val(s_e,e_e,s_i,e_i,input_id)
            
    
    st_acc = calc_ans(st_flat, st_ids) / len(st_ids)
    end_acc = calc_ans(end_flat, end_ids) / len(end_ids)
    
    acc = calc_ans_bt(st_flat,end_flat,st_ids,end_ids,input_ids) / len(st_ids)

    return st_acc,end_acc,acc

def format_time(elapsed):

    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [18]:
seed_val = 0
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

model.zero_grad()

max_acc = 0

for epoch_i in range(0, epochs):
    print(f'Training...\n======== Epoch {epoch_i + 1} / {epochs} ========')
    
    t0 = time.time()
    total_loss = 0
    prev_loss = 0
    prev_step = 0
    
    model.train()
    
    for step, batch in enumerate(tqdm(train_dataloader)):

        if step % 500 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            step_loss = (total_loss-prev_loss) / (step-prev_step)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
            print(f'  Average loss =  {step_loss}')
            print(f"  Last learning rate: {scheduler.get_last_lr()[0]:.3}")
            
            wandb.log({"step loss": step_loss}, step=epoch_i*len(train_dataloader)+step)
            prev_loss = total_loss
            prev_step = step
                        
        if step % 3000 == 0 and not step == 0:
            torch.save(model,root_address+f'model_sv_{epoch_i}_on.pt')
            print(f'model_sv_{epoch_i}_on has been saved')
            

        batch = tuple(t.to(device) for t in batch)
        
        b_input_ids, b_input_mask, b_starts, b_ends = batch

        optimizer.zero_grad()

        # Forward         
        outputs = model(b_input_ids, 
                        attention_mask = b_input_mask,
                        start_positions = b_starts,
                        end_positions = b_ends)
        
        st_logits = outputs['start_logits']
        # print('st_log',st_logits)
        # print('st_log_max',max(st_logits[0]))
        end_logits = outputs['end_logits']
        # print('end_log',end_logits)
        
        loss = outputs[0]
        # print('loss', loss)

        total_loss += loss.item()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        scheduler.step()


    avg_train_loss = total_loss / len(train_dataloader)            

    print("")
    print("  Average training loss: {:}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(format_time(time.time() - t0)))
    print(f"  Last learning rate: {scheduler.get_last_lr()[0]:.3}")
    print("")        
    # ========================================
    #               Validation
    # ========================================

    print("Running Validation...")


    t0 = time.time()

    torch.save(model,root_address+f'model_sv_{epoch_i}.pt')
    print(f'model_sv_{epoch_i} has been saved')

    if epoch_i >=2:
        os.remove(root_address+f'model_sv_{epoch_i-2}.pt')
        os.remove(root_address+f'model_sv_{epoch_i-2}_on.pt')
    #Evaluation
    model.eval()

    eval_loss, eval_accuracy1,eval_accuracy2,eval_accuracy3 = 0, 0, 0, 0
    nb_eval_steps = 0

    for batch in tqdm(validation_dataloader):

        batch = tuple(t.to(device) for t in batch)
        

        b_input_ids, b_input_mask, b_starts, b_ends = batch

        with torch.no_grad():     
            outputs = model(b_input_ids, 
                            token_type_ids=None, 
                            attention_mask=b_input_mask)
        
        st_logits = outputs['start_logits']
        end_logits = outputs['end_logits']

        st_logits = st_logits.detach().cpu().numpy()
        end_logits = end_logits.detach().cpu().numpy()
        start_ids = b_starts.to('cpu').numpy()
        end_ids = b_ends.to('cpu').numpy()
        
        
        tmp_eval_accuracy1,tmp_eval_accuracy2,tmp_eval_accuracy3 = flat_accuracy(st_logits,end_logits, start_ids, end_ids,input_ids = b_input_ids)
        eval_accuracy1 += tmp_eval_accuracy1
        eval_accuracy2 += tmp_eval_accuracy2
        eval_accuracy3 += tmp_eval_accuracy3

        nb_eval_steps += 1
        
    acc_all = eval_accuracy3/nb_eval_steps

    print("  st_token Accuracy: {0:.2f}".format(eval_accuracy1/nb_eval_steps))
    print("  end_token Accuracy: {0:.2f}".format(eval_accuracy2/nb_eval_steps))
    print("  all Accuracy: {0:.2f}".format(acc_all))
    print("  Validation took: {:}".format(format_time(time.time() - t0)))
    
    wandb.log({"val_accuracy": acc_all,"loss": avg_train_loss, "learning_rate": scheduler.get_last_lr()[0]}, step=epoch_i)
    
    if max_acc < acc_all:
        print("")
        print(f"Accuracy_nonzero: {acc_all:0.2f} is higher than previous maximum accuracy: {max_acc:0.2f}")
        max_acc = acc_all
        torch.save(model,root_address+'model_max_acc.pt')
        print(f'New model_max_acc has been saved')    
    
    print("")

    # ========================================
    #                  Test
    # ========================================
    if epoch_i%5 ==0 and epoch_i !=0:    
        print("Running Test...")

        t0 = time.time()
            
        model.eval()

        eval_loss, eval_accuracy1,eval_accuracy2,eval_accuracy3 = 0, 0, 0, 0
        nb_eval_steps = 0

        for batch in tqdm(test_dataloader):

            batch = tuple(t.to(device) for t in batch)

            b_input_ids, b_input_mask, b_starts, b_ends = batch

            with torch.no_grad():     
                outputs = model(b_input_ids, 
                                token_type_ids=None, 
                                attention_mask=b_input_mask)
            
            st_logits = outputs['start_logits']
            end_logits = outputs['end_logits']

            st_logits = st_logits.detach().cpu().numpy()
            end_logits = end_logits.detach().cpu().numpy()
            start_ids = b_starts.to('cpu').numpy()
            end_ids = b_ends.to('cpu').numpy()
            
            
            tmp_eval_accuracy1,tmp_eval_accuracy2,tmp_eval_accuracy3 = flat_accuracy(st_logits,end_logits, start_ids, end_ids, input_ids=b_input_ids)
            eval_accuracy1 += tmp_eval_accuracy1
            eval_accuracy2 += tmp_eval_accuracy2
            eval_accuracy3 += tmp_eval_accuracy3

            nb_eval_steps += 1
            
        acc_all = eval_accuracy3/nb_eval_steps
        
        wandb.log({"test_accuracy": acc_all}, step=epoch_i)
        
        print("  st_token Accuracy: {0:.2f}".format(eval_accuracy1/nb_eval_steps))
        print("  end_token Accuracy: {0:.2f}".format(eval_accuracy2/nb_eval_steps))
        print("  all Accuracy: {0:.2f}".format(acc_all))
        print("  Test took: {:}".format(format_time(time.time() - t0)))

        print("")
    
print("Training complete!")

torch.save(model,root_address+'model_fin.pt')

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:11.
  Average loss =  0.7725441969037056
  Last learning rate: 9.99e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:18.
  Average loss =  0.7500127480626106
  Last learning rate: 9.97e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:26.
  Average loss =  0.7436734752058983
  Last learning rate: 9.93e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:34.
  Average loss =  0.7672943090200424
  Last learning rate: 9.88e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:43.
  Average loss =  0.7592478284835815
  Last learning rate: 9.81e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:52.
  Average loss =  0.7728655940592289
  Last learning rate: 9.72e-06
model_sv_0_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:57.
  Average loss =  0.7638937972187996
  Last learning rate: 9.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:54:07.
  Average loss =  0.7839522058963776
  Last learning rate: 9.51e-06
  Batch 4,500  of  5,641.    Elapsed: 2:08:18.
  Av

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.60
  end_token Accuracy: 0.60
  all Accuracy: 0.58
  Validation took: 0:06:22

Accuracy_nonzero: 0.58 is higher than previous maximum accuracy: 0.00
New model_max_acc has been saved

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:10.
  Average loss =  0.7375637495517731
  Last learning rate: 8.88e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:20.
  Average loss =  0.7362869591116905
  Last learning rate: 8.69e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:29.
  Average loss =  0.7469235206246376
  Last learning rate: 8.5e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:38.
  Average loss =  0.7400409332066774
  Last learning rate: 8.3e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:46.
  Average loss =  0.7533619495034218
  Last learning rate: 8.08e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:56.
  Average loss =  0.7474973101317882
  Last learning rate: 7.86e-06
model_sv_1_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:18.
  Average loss =  0.7640580903887749
  Last learning rate: 7.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:29.
  Average loss =  0.7548687813282013
  Last learning rate: 7.38e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:39.
  Aver

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.60
  end_token Accuracy: 0.60
  all Accuracy: 0.57
  Validation took: 0:06:21

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:09.
  Average loss =  0.7346827134490013
  Last learning rate: 6.28e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:16.
  Average loss =  0.7254240836501121
  Last learning rate: 6.01e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:24.
  Average loss =  0.7338742761015892
  Last learning rate: 5.73e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:32.
  Average loss =  0.7534537640213966
  Last learning rate: 5.46e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:40.
  Average loss =  0.7402179727554321
  Last learning rate: 5.18e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:47.
  Average loss =  0.739889804661274
  Last learning rate: 4.9e-06
model_sv_2_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:08.
  Average loss =  0.7453472906649112
  Last learning rate: 4.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:15.
  Average loss =  0.7562623009085655
  Last learning rate: 4.34e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:23.
  Aver

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.60
  all Accuracy: 0.58
  Validation took: 0:06:20

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:08.
  Average loss =  0.7292264721393585
  Last learning rate: 3.19e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:16.
  Average loss =  0.7324612098932266
  Last learning rate: 2.94e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:24.
  Average loss =  0.7454685433506966
  Last learning rate: 2.69e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:32.
  Average loss =  0.7487250540852547
  Last learning rate: 2.44e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:40.
  Average loss =  0.7430379568338394
  Last learning rate: 2.21e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:47.
  Average loss =  0.7462230242490768
  Last learning rate: 1.98e-06
model_sv_3_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:07.
  Average loss =  0.7583167610168458
  Last learning rate: 1.76e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:14.
  Average loss =  0.758340856730938
  Last learning rate: 1.56e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:20.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.60
  all Accuracy: 0.57
  Validation took: 0:06:20

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:07.
  Average loss =  0.7587836250066757
  Last learning rate: 7.98e-07
  Batch 1,000  of  5,641.    Elapsed: 0:28:14.
  Average loss =  0.7452387846112252
  Last learning rate: 6.53e-07
  Batch 1,500  of  5,641.    Elapsed: 0:42:20.
  Average loss =  0.7475883154273033
  Last learning rate: 5.22e-07
  Batch 2,000  of  5,641.    Elapsed: 0:56:24.
  Average loss =  0.7436265663504601
  Last learning rate: 4.06e-07
  Batch 2,500  of  5,641.    Elapsed: 1:10:26.
  Average loss =  0.745103566467762
  Last learning rate: 3.03e-07
  Batch 3,000  of  5,641.    Elapsed: 1:24:27.
  Average loss =  0.7500061175823212
  Last learning rate: 2.15e-07
model_sv_4_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:51.
  Average loss =  0.7609812605977059
  Last learning rate: 1.42e-07
  Batch 4,000  of  5,641.    Elapsed: 1:52:54.
  Average loss =  0.7518181788921356
  Last learning rate: 8.33e-08
  Batch 4,500  of  5,641.    Elapsed: 2:06:55.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:19

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:03.
  Average loss =  0.6847846153378486
  Last learning rate: 9.99e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:05.
  Average loss =  0.7214593588709831
  Last learning rate: 9.97e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:07.
  Average loss =  0.7163117345273494
  Last learning rate: 9.93e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:09.
  Average loss =  0.7242327590584755
  Last learning rate: 9.88e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:10.
  Average loss =  0.7343321932554245
  Last learning rate: 9.81e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:10.
  Average loss =  0.7380776501297951
  Last learning rate: 9.72e-06
model_sv_5_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:07.
  Average loss =  0.749914744079113
  Last learning rate: 9.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:09.
  Average loss =  0.7530464440584183
  Last learning rate: 9.51e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:11.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:19

Running Test...


  0%|          | 0/2687 [00:00<?, ?it/s]

  st_token Accuracy: 0.60
  end_token Accuracy: 0.60
  all Accuracy: 0.58
  Test took: 0:26:13

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:02.
  Average loss =  0.7721881268024444
  Last learning rate: 8.88e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:04.
  Average loss =  0.7780948332250118
  Last learning rate: 8.69e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:04.
  Average loss =  0.7809246983528138
  Last learning rate: 8.5e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:05.
  Average loss =  0.7689284391403198
  Last learning rate: 8.3e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:06.
  Average loss =  0.7842823544740677
  Last learning rate: 8.08e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:08.
  Average loss =  0.7900302202403545
  Last learning rate: 7.86e-06
model_sv_6_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:22.
  Average loss =  0.7854645727872849
  Last learning rate: 7.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:52:25.
  Average loss =  0.7871225236654281
  Last learning rate: 7.38e-06
  Batch 4,500  of  5,641.    Elapsed: 2:06:27.
  Aver

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.60
  all Accuracy: 0.57
  Validation took: 0:06:19

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:02.
  Average loss =  0.7597169224619865
  Last learning rate: 6.28e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:04.
  Average loss =  0.7574184155464172
  Last learning rate: 6.01e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:06.
  Average loss =  0.7713719552159309
  Last learning rate: 5.73e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:07.
  Average loss =  0.7796014271378517
  Last learning rate: 5.46e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:10.
  Average loss =  0.7789711985588074
  Last learning rate: 5.18e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:11.
  Average loss =  0.7844129683971405
  Last learning rate: 4.9e-06
model_sv_7_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:26.
  Average loss =  0.7686680991649628
  Last learning rate: 4.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:52:28.
  Average loss =  0.7581752816438675
  Last learning rate: 4.34e-06
  Batch 4,500  of  5,641.    Elapsed: 2:06:30.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.58
  Validation took: 0:06:20

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:02.
  Average loss =  0.7400427107810974
  Last learning rate: 3.19e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:04.
  Average loss =  0.7594538595676422
  Last learning rate: 2.94e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:05.
  Average loss =  0.7553800438046455
  Last learning rate: 2.69e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:06.
  Average loss =  0.7566396964788437
  Last learning rate: 2.44e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:07.
  Average loss =  0.7484689727425575
  Last learning rate: 2.21e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:09.
  Average loss =  0.7619136061668396
  Last learning rate: 1.98e-06
model_sv_8_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:24.
  Average loss =  0.7628930844664573
  Last learning rate: 1.76e-06
  Batch 4,000  of  5,641.    Elapsed: 1:52:27.
  Average loss =  0.7544398324489594
  Last learning rate: 1.56e-06
  Batch 4,500  of  5,641.    Elapsed: 2:06:29.
  Av

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.58
  Validation took: 0:06:20

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:03.
  Average loss =  0.7329920762777329
  Last learning rate: 7.98e-07
  Batch 1,000  of  5,641.    Elapsed: 0:28:05.
  Average loss =  0.7489016938209534
  Last learning rate: 6.53e-07
  Batch 1,500  of  5,641.    Elapsed: 0:42:08.
  Average loss =  0.7341307832002639
  Last learning rate: 5.22e-07
  Batch 2,000  of  5,641.    Elapsed: 0:56:17.
  Average loss =  0.7528952392935753
  Last learning rate: 4.06e-07
  Batch 2,500  of  5,641.    Elapsed: 1:10:25.
  Average loss =  0.7511899518072606
  Last learning rate: 3.03e-07
  Batch 3,000  of  5,641.    Elapsed: 1:24:33.
  Average loss =  0.7418562126755714
  Last learning rate: 2.15e-07
model_sv_9_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:55.
  Average loss =  0.7434219599366189
  Last learning rate: 1.42e-07
  Batch 4,000  of  5,641.    Elapsed: 1:53:03.
  Average loss =  0.7644230405688286
  Last learning rate: 8.33e-08
  Batch 4,500  of  5,641.    Elapsed: 2:07:11.
  Av

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:21

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:09.
  Average loss =  0.7660327640771866
  Last learning rate: 9.99e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:18.
  Average loss =  0.766852909386158
  Last learning rate: 9.97e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:27.
  Average loss =  0.768072758436203
  Last learning rate: 9.93e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:35.
  Average loss =  0.7881638633012772
  Last learning rate: 9.88e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:43.
  Average loss =  0.7729074205160141
  Last learning rate: 9.81e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:53.
  Average loss =  0.7842452843785286
  Last learning rate: 9.72e-06
model_sv_10_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:26.
  Average loss =  0.7941295467615127
  Last learning rate: 9.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:35.
  Average loss =  0.7849389755725861
  Last learning rate: 9.51e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:44.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:33

Running Test...


  0%|          | 0/2687 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Test took: 0:26:22

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:08.
  Average loss =  0.7669549003243447
  Last learning rate: 8.88e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:16.
  Average loss =  0.7713912106752395
  Last learning rate: 8.69e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:23.
  Average loss =  0.7655382486581802
  Last learning rate: 8.5e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:33.
  Average loss =  0.7774509361386299
  Last learning rate: 8.3e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:42.
  Average loss =  0.7807841214537621
  Last learning rate: 8.08e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:50.
  Average loss =  0.7837914015054703
  Last learning rate: 7.86e-06
model_sv_11_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:23.
  Average loss =  0.7776070953011512
  Last learning rate: 7.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:33.
  Average loss =  0.7726114713549614
  Last learning rate: 7.38e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:42.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:07:14

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:10.
  Average loss =  0.7601217958331108
  Last learning rate: 6.28e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:19.
  Average loss =  0.7627727025449276
  Last learning rate: 6.01e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:27.
  Average loss =  0.7451244171261787
  Last learning rate: 5.73e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:37.
  Average loss =  0.7567648948132992
  Last learning rate: 5.46e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:45.
  Average loss =  0.7704249260425567
  Last learning rate: 5.18e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:54.
  Average loss =  0.7675686801075935
  Last learning rate: 4.9e-06
model_sv_12_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:17.
  Average loss =  0.756743400156498
  Last learning rate: 4.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:27.
  Average loss =  0.7640956081151963
  Last learning rate: 4.34e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:35.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:21

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:09.
  Average loss =  0.7415168458223342
  Last learning rate: 3.19e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:18.
  Average loss =  0.7409661186933517
  Last learning rate: 2.94e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:27.
  Average loss =  0.7676552320122719
  Last learning rate: 2.69e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:36.
  Average loss =  0.7463097547888756
  Last learning rate: 2.44e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:45.
  Average loss =  0.7462800906300545
  Last learning rate: 2.21e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:53.
  Average loss =  0.7541936403512954
  Last learning rate: 1.98e-06
model_sv_13_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:15.
  Average loss =  0.7416354954838753
  Last learning rate: 1.76e-06
  Batch 4,000  of  5,641.    Elapsed: 1:53:24.
  Average loss =  0.7496568002104759
  Last learning rate: 1.56e-06
  Batch 4,500  of  5,641.    Elapsed: 2:07:32.
  A

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:21

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:09.
  Average loss =  0.7375157542824745
  Last learning rate: 7.98e-07
  Batch 1,000  of  5,641.    Elapsed: 0:28:18.
  Average loss =  0.7327121524810791
  Last learning rate: 6.53e-07
  Batch 1,500  of  5,641.    Elapsed: 0:42:26.
  Average loss =  0.7370313926339149
  Last learning rate: 5.22e-07
  Batch 2,000  of  5,641.    Elapsed: 0:56:35.
  Average loss =  0.7299596942663192
  Last learning rate: 4.06e-07
  Batch 2,500  of  5,641.    Elapsed: 1:10:43.
  Average loss =  0.7411395879387855
  Last learning rate: 3.03e-07
  Batch 3,000  of  5,641.    Elapsed: 1:24:48.
  Average loss =  0.7522118650078774
  Last learning rate: 2.15e-07
model_sv_14_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:39:08.
  Average loss =  0.7354809569716454
  Last learning rate: 1.42e-07
  Batch 4,000  of  5,641.    Elapsed: 1:53:15.
  Average loss =  0.747198430120945
  Last learning rate: 8.33e-08
  Batch 4,500  of  5,641.    Elapsed: 2:07:21.
  Av

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:20

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:07.
  Average loss =  0.7520213257074356
  Last learning rate: 9.99e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:14.
  Average loss =  0.756963781774044
  Last learning rate: 9.97e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:21.
  Average loss =  0.7660348912775516
  Last learning rate: 9.93e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:26.
  Average loss =  0.7679939370155334
  Last learning rate: 9.88e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:28.
  Average loss =  0.7729093528985977
  Last learning rate: 9.81e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:30.
  Average loss =  0.7751827995181084
  Last learning rate: 9.72e-06
model_sv_15_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:44.
  Average loss =  0.7707837264537811
  Last learning rate: 9.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:52:46.
  Average loss =  0.7809451416134834
  Last learning rate: 9.51e-06
  Batch 4,500  of  5,641.    Elapsed: 2:06:47.
  Av

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:20

Running Test...


  0%|          | 0/2687 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Test took: 0:26:16

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:02.
  Average loss =  0.7666589644551277
  Last learning rate: 8.88e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:03.
  Average loss =  0.7645531268119812
  Last learning rate: 8.69e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:04.
  Average loss =  0.7616048728525638
  Last learning rate: 8.5e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:06.
  Average loss =  0.7698432173132896
  Last learning rate: 8.3e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:06.
  Average loss =  0.7719360879063606
  Last learning rate: 8.08e-06
  Batch 3,000  of  5,641.    Elapsed: 1:24:09.
  Average loss =  0.7625606741309165
  Last learning rate: 7.86e-06
model_sv_16_on has been saved
  Batch 3,500  of  5,641.    Elapsed: 1:38:23.
  Average loss =  0.7652093903422356
  Last learning rate: 7.62e-06
  Batch 4,000  of  5,641.    Elapsed: 1:52:26.
  Average loss =  0.7581919471621513
  Last learning rate: 7.38e-06
  Batch 4,500  of  5,641.    Elapsed: 2:06:28.
  Ave

  0%|          | 0/627 [00:00<?, ?it/s]

  st_token Accuracy: 0.59
  end_token Accuracy: 0.59
  all Accuracy: 0.57
  Validation took: 0:06:32

Training...


  0%|          | 0/5641 [00:00<?, ?it/s]

  Batch   500  of  5,641.    Elapsed: 0:14:03.
  Average loss =  0.7649317989945412
  Last learning rate: 6.28e-06
  Batch 1,000  of  5,641.    Elapsed: 0:28:05.
  Average loss =  0.7499041575491429
  Last learning rate: 6.01e-06
  Batch 1,500  of  5,641.    Elapsed: 0:42:08.
  Average loss =  0.7675535350143909
  Last learning rate: 5.73e-06
  Batch 2,000  of  5,641.    Elapsed: 0:56:10.
  Average loss =  0.7461438678205013
  Last learning rate: 5.46e-06
  Batch 2,500  of  5,641.    Elapsed: 1:10:12.
  Average loss =  0.7546912826895714
  Last learning rate: 5.18e-06


KeyboardInterrupt: 

In [None]:
# ========================================
#                  Test
# ========================================
print("Running Test...")

t0 = time.time()
    
model.eval()

eval_loss, eval_accuracy1,eval_accuracy2,eval_accuracy3 = 0, 0, 0, 0
nb_eval_steps = 0

for batch in tqdm(test_dataloader):

    batch = tuple(t.to(device) for t in batch)
    

    b_input_ids, b_input_mask, b_starts, b_ends = batch

    with torch.no_grad():     
        outputs = model(b_input_ids, 
                        token_type_ids=None, 
                        attention_mask=b_input_mask)
    
    st_logits = outputs['start_logits']
    end_logits = outputs['end_logits']

    st_logits = st_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    start_ids = b_starts.to('cpu').numpy()
    end_ids = b_ends.to('cpu').numpy()
    
    
    tmp_eval_accuracy1,tmp_eval_accuracy2,tmp_eval_accuracy3 = flat_accuracy(st_logits,end_logits, start_ids, end_ids, input_ids=b_input_ids)
    eval_accuracy1 += tmp_eval_accuracy1
    eval_accuracy2 += tmp_eval_accuracy2
    eval_accuracy3 += tmp_eval_accuracy3

    nb_eval_steps += 1
    
acc_all = eval_accuracy3/nb_eval_steps
wandb.log({"End_test_accuracy": acc_all})

print("  st_token Accuracy: {0:.2f}".format(eval_accuracy1/nb_eval_steps))
print("  end_token Accuracy: {0:.2f}".format(eval_accuracy2/nb_eval_steps))
print("  all Accuracy: {0:.2f}".format(acc_all))
print("  Test took: {:}".format(format_time(time.time() - t0)))

print("")