In [1]:
import argparse
import json
import random
import time
import os
import warnings
warnings.filterwarnings("ignore")

from copy import deepcopy
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.style.use("ggplot")

import sklearn
from sklearn.model_selection import train_test_split, StratifiedKFold

import transformers
from transformers import BertTokenizer, BertModel, ElectraTokenizer, ElectraModel, AdamW, get_linear_schedule_with_warmup

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset

import torchvision
from torchvision import transforms, datasets

from Utils.dataset import *
from Utils.utils import *
from Models.BertClf import *
from Models.LstmClf import *
from Models.ElectraClf import *
from Models.ConvClf import *

#################################################################################################################
# Library Version
#################################################################################################################
print(f"pandas version: {pd.__version__}")
print(f"numpy version: {np.__version__}")
print(f"seaborn version: {sns.__version__}")
print(f"matplotlib version: {mpl.__version__}")
print(f"sklearn version: {sklearn.__version__}")
print(f"transformers version: {transformers.__version__}")
print(f"torch version: {torch.__version__}")

#################################################################################################################
# Reproducible
#################################################################################################################
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(42)
random.seed(42)
os.environ['PYTHONHASHSEED'] = str(42)

#################################################################################################################
# Hyperparameters Setting
#################################################################################################################
parser = argparse.ArgumentParser()

# model
parser.add_argument('--model', type=str, default='CNN', help='BERT, BILSTM, ELECTRA, CNN')
parser.add_argument('--sent_embedding', type=int, default=0, help='0: CLS, 1: 4-layer concat')
parser.add_argument('--hidden_dim', type=int, default=768, help='BERT or ELECTRA: hidden dimension of classifier, BILSTM: hidden dimension of lstm')
parser.add_argument('--num_layer', type=int, default=2, help='BILSTM: number of layers of lstm')
parser.add_argument("--embedding_dim", type=int, default=256, help='embedding dimension of CNN')
parser.add_argument("--kernel_sizes", nargs='+', default=[3, 4, 5], type=int, help='kernel sizes of CNN')
parser.add_argument("--kernel_depth", default=500, type=int, help='kernel depth of CNN')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout ratio')

# training
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--gpu', type=int, default=0, help='0,1,2,3')
parser.add_argument('--max_epoch', type=int, default=50)
parser.add_argument('--save', type=int, default=1, help='0: false, 1:true')
parser.add_argument('--lr_pretrained', type=float, default=1e-05, help='learning rate, 5e-5, 3e-5 or 2e-5')
parser.add_argument('--lr_clf', type=float, default=0.0001, help='learning rate, 5e-5, 3e-5 or 2e-5')
parser.add_argument('--freeze_pretrained', type=int, default=0, help='0: false, 1:true')
parser.add_argument('--eps', type=float, default=1e-8, help='epsilon for AdamW, 1e-8')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay for AdamW, 5e-4')

# dataset
parser.add_argument('--data_path', type=str, default='./Dataset')
parser.add_argument('--save_model_path', type=str, default='./Saved_models')
parser.add_argument('--save_submission_path', type=str, default='./Submissions')
parser.add_argument('--max_len', type=int, default=50, help='max length of the sentence')
parser.add_argument('--aug', type=int, default=0, help='0: false, 1: true(ru)')
parser.add_argument('--split_ratio', type=int, default=1, help='k/10, k in [1,2,3]')
parser.add_argument('--author', type=str, default='jh')


#     opt = parser.parse_args() # in .py env
opt, _ = parser.parse_known_args() # in .ipynb env

#################################################################################################################
# Training Device
#################################################################################################################
device = torch.device("cuda:" + str(opt.gpu)) if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.set_device(device) # change allocation of current GPU
print(f'training device: {device, torch.cuda.get_device_name()}')
curr_time = time.localtime()
signature = f"{opt.author}_{opt.model}_{curr_time.tm_mon}M_{curr_time.tm_mday}D_{curr_time.tm_hour}H_{curr_time.tm_min}M"
opt.signature = signature
print(f'signature: {signature}')
with open('./Saved_models/' + signature + '_opt.txt', 'w') as f:
    json.dump(opt.__dict__, f, indent=2)

pandas version: 1.2.4
numpy version: 1.20.2
seaborn version: 0.11.1
matplotlib version: 3.4.1
sklearn version: 0.24.2
transformers version: 4.5.1
torch version: 1.8.1+cu102
training device: (device(type='cuda', index=0), 'TITAN Xp')
signature: jh_CNN_5M_31D_17H_36M


In [2]:
#################################################################################################################
# Train and Evaluate
#################################################################################################################
def train_fn(model,
             optimizer,
             scheduler,
             loss_fn,
             train_dataloader,
             valid_dataloader=None,
             evaluation=False):
    """
    Train the BertClassifier model with early stop trick.
    
    :param model: untrained model
    :param train_dataloader: dataloader which is obtained by data_load method
    :param valid_dataloader: dataloader which is obtained by data_load method
    :param epochs: opt.max_epoch [int]
    :param evaluation: [bool]
    """
    # Start training loop
    print("Start training...\n")
    es_eval_dict = {
        "epoch": 0,
        "train_loss": 0,
        "valid_loss": 0,
        "valid_acc": 0
    }  # early stop
    for epoch_i in range(opt.max_epoch):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(
            f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
        )
        print("-" * 70)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # Put the model into the training mode
        model.train()

        # For each batch of training data...
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            # Load batch to GPU
            b_ids_tsr, b_masks_tsr, b_labels_tsr = tuple(
                tsrs.to(device) for tsrs in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. This will return logits.
            if opt.model in ["BILSTM", "CNN"]:
                logits = model(b_ids_tsr)
            else:
                logits = model(b_ids_tsr, b_masks_tsr)

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels_tsr)
            batch_loss += loss.item()
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()

            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0
                    and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch

                # Print training results
                print(
                    f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}"
                )

                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)

        print("-" * 70)
        # =======================================
        #               Evaluation
        # =======================================
        model_save_path = str(
            opt.save_model_path) + "/" + opt.signature + '.model'
        if evaluation == True:
            previous_valid_acc = es_eval_dict["valid_acc"]  # early stop
            # After the completion of each training epoch, measure the model's performance
            # on our validation set.
            valid_loss, valid_acc = evaluate_fn(model, loss_fn,
                                                valid_dataloader)

            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            print(
                f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {valid_loss:^10.6f} | {valid_acc:^9.2f} | {time_elapsed:^9.2f}"
            )
            print("-" * 70)
            if previous_valid_acc < valid_acc:
                es_eval_dict["epoch"] = epoch_i
                es_eval_dict["train_loss"] = avg_train_loss
                es_eval_dict["valid_loss"] = valid_loss
                es_eval_dict["valid_acc"] = valid_acc
                if opt.save == 1:
                    torch.save(model.state_dict(), model_save_path)
                    print('\tthe model is improved... save at',
                          model_save_path)
        print("\n")
    print("Final results table")
    print("-" * 70)
    print(
        f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
    )
    final_epoch, final_train_loss, final_valid_loss, final_valid_acc = es_eval_dict[
        "epoch"], es_eval_dict["train_loss"], es_eval_dict[
            "valid_loss"], es_eval_dict["valid_acc"]
    print(
        f"{final_epoch + 1:^7} | {'-':^7} | {final_train_loss:^12.6f} | {final_valid_loss:^10.6f} | {final_valid_acc:^9.2f} | {0:^9.2f}"
    )
    print("-" * 70)
    print("Training complete!")
    return model, final_train_loss, final_valid_loss, final_valid_acc


def evaluate_fn(model, loss_fn, valid_dataloader):
    """
    After the completion of each training epoch, measure the model's performance on our validation set.
    
    :param model: trained model
    :param valid_dataloader: dataloader which is obtained by data_load method
    
    :return valid_loss: validation loss [array]
    :return valid_acc: validation accuracy [array]
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()

    # Tracking variables
    valid_acc = []
    valid_loss = []

    # For each batch in our validation set...
    for batch in valid_dataloader:
        # Load batch to GPU
        b_ids_tsr, b_masks_tsr, b_labels_tsr = tuple(
            t.to(device) for t in batch)

        # Compute logits
        with torch.no_grad():
            if opt.model in ["BILSTM", "CNN"]:
                logits = model(b_ids_tsr)
            else:
                logits = model(b_ids_tsr, b_masks_tsr)

        # Compute loss
        loss = loss_fn(logits, b_labels_tsr)
        valid_loss.append(loss.item())

        # Get the predictions
        preds = torch.argmax(logits, dim=1).flatten()

        # Calculate the accuracy rate
        accuracy = (preds == b_labels_tsr).cpu().numpy().mean() * 100
        valid_acc.append(accuracy)

    # Compute the average accuracy and loss over the validation set.
    valid_loss = np.mean(valid_loss)
    valid_acc = np.mean(valid_acc)

    return valid_loss, valid_acc


def cross_validation(full_dataset=None, n_splits=5):
    """Define a cross validation function
    """
    train_loss_list, valid_loss_list, valid_acc_list = [], [], []
    full_ids = full_dataset.ids_tsr.detach().cpu().numpy()
    full_labels = full_dataset.labels.detach().cpu().numpy()
    skf = StratifiedKFold(n_splits=n_splits, shuffle=False)
    for i, idx in enumerate(skf.split(full_ids, full_labels)):
        print(f"Start {i}-th cross validation...\n")
        train_indices, valid_indices = idx[0], idx[1]
        print(train_indices)
        print(valid_indices)

        train_subset = torch.utils.data.dataset.Subset(full_dataset,
                                                       train_indices)
        valid_subset = torch.utils.data.dataset.Subset(full_dataset,
                                                       valid_indices)

        print(
            f"len of train set: {len(train_subset)}, len of valid set: {len(valid_subset)}"
        )
        print()

        train_dataloader = DataLoader(
            train_subset,
            batch_size=opt.batch_size,
            shuffle=True,
        )
        valid_dataloader = DataLoader(
            valid_subset,
            batch_size=opt.batch_size,
            shuffle=True,
        )

        # Specify the loss function
        loss_fn = nn.CrossEntropyLoss()

        # Initialize the model
        untrained_model, optimizer, scheduler = initialize_model(
            opt, len(train_dataloader), device)

        _, train_loss, valid_loss, valid_acc = train_fn(untrained_model,
                                                        optimizer,
                                                        scheduler,
                                                        loss_fn,
                                                        train_dataloader,
                                                        valid_dataloader,
                                                        evaluation=True)

        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        valid_acc_list.append(valid_acc)

        print(f"...Complete {i}-th cross validation\n")
    train_loss_arr = np.array(train_loss_list)
    valid_loss_arr = np.array(valid_loss_list)
    valid_acc_arr = np.array(valid_acc_list)
    valid_avg_score = np.mean(valid_acc_arr)
    print("=" * 60)
    print(f"Average valid accuracy: {valid_avg_score}")
    print("=" * 60)
    return train_loss_arr, valid_loss_arr, valid_acc_arr, valid_avg_score

In [3]:
# # k-cross validation
# full_dataset = FullDataset(opt)
# train_loss_arr, valid_loss_arr, valid_acc_arr, valid_avg_score = cross_validation(
#     full_dataset=full_dataset,
#     n_splits=8)

In [4]:
# Load the DataLoaders
train_dataloader, valid_dataloader, test_dataloader = data_load(opt)

# Specify the loss function
loss_fn = nn.CrossEntropyLoss()

# Initialize the model
untrained_model, optimizer, scheduler = initialize_model(opt, len(train_dataloader), device)

trained_model, _, _, _ = train_fn(untrained_model, optimizer, scheduler, loss_fn, train_dataloader, valid_dataloader=valid_dataloader, evaluation=True)

Tokenizing data...
Apply the BertTokenizer...


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


train_X_ids_tsr.shape: torch.Size([7805, 50])
train_X_masks_tsr.shape: torch.Size([7805, 50])
Tokenizing data...
Apply the BertTokenizer...


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


valid_X_ids_tsr.shape: torch.Size([1748, 50])
valid_X_masks_tsr.shape: torch.Size([1748, 50])
Tokenizing data...
Apply the BertTokenizer...


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


test_X_ids_tsr.shape: torch.Size([4311, 50])
test_X_masks_tsr.shape: torch.Size([4311, 50])
num of train_loader: 7805
num of valid_loader: 1748
num of test_loader: 4311
Start training...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   1    |   20    |   1.708169   |     -      |     -     |   0.18   
   1    |   40    |   1.595107   |     -      |     -     |   0.16   
   1    |   60    |   1.587543   |     -      |     -     |   0.15   
   1    |   80    |   1.650986   |     -      |     -     |   0.15   
   1    |   100   |   1.606898   |     -      |     -     |   0.15   
   1    |   120   |   1.609489   |     -      |     -     |   0.15   
   1    |   140   |   1.604599   |     -      |     -     |   0.16   
   1    |   160   |   1.593350   |     -      |     -     |   0.16   
   1    |   180   |   1.595814   |     -      |     -     |   0.15   
   1    |   200   |   1.642261   |     - 

   4    |   400   |   1.351756   |     -      |     -     |   0.18   
   4    |   420   |   1.427427   |     -      |     -     |   0.18   
   4    |   440   |   1.353783   |     -      |     -     |   0.18   
   4    |   460   |   1.450025   |     -      |     -     |   0.18   
   4    |   480   |   1.442045   |     -      |     -     |   0.19   
   4    |   487   |   1.350683   |     -      |     -     |   0.06   
----------------------------------------------------------------------
   4    |    -    |   1.403147   |  1.397316  |   40.28   |   4.48   
----------------------------------------------------------------------
	the model is improved... save at ./Saved_models/jh_CNN_5M_31D_17H_36M.model


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   5    |   20    |   1.324857   |     -      |     -     |   0.20   
   5    |   40    |   1.332666   |     -      |     -     |   0.18   
   5    

   8    |   280   |   1.132637   |     -      |     -     |   0.18   
   8    |   300   |   1.087464   |     -      |     -     |   0.17   
   8    |   320   |   1.112744   |     -      |     -     |   0.17   
   8    |   340   |   1.141772   |     -      |     -     |   0.18   
   8    |   360   |   1.103894   |     -      |     -     |   0.17   
   8    |   380   |   1.133211   |     -      |     -     |   0.17   
   8    |   400   |   1.106470   |     -      |     -     |   0.17   
   8    |   420   |   1.117282   |     -      |     -     |   0.17   
   8    |   440   |   1.095366   |     -      |     -     |   0.17   
   8    |   460   |   1.205292   |     -      |     -     |   0.17   
   8    |   480   |   1.189328   |     -      |     -     |   0.17   
   8    |   487   |   1.116040   |     -      |     -     |   0.06   
----------------------------------------------------------------------
   8    |    -    |   1.138715   |  1.165120  |   56.93   |   4.35   
-------------------

  12    |   160   |   0.849692   |     -      |     -     |   0.17   
  12    |   180   |   0.914344   |     -      |     -     |   0.17   
  12    |   200   |   0.877858   |     -      |     -     |   0.17   
  12    |   220   |   0.842653   |     -      |     -     |   0.17   
  12    |   240   |   0.893337   |     -      |     -     |   0.17   
  12    |   260   |   0.893742   |     -      |     -     |   0.17   
  12    |   280   |   0.843315   |     -      |     -     |   0.17   
  12    |   300   |   0.787841   |     -      |     -     |   0.17   
  12    |   320   |   0.825400   |     -      |     -     |   0.18   
  12    |   340   |   0.893396   |     -      |     -     |   0.17   
  12    |   360   |   0.811515   |     -      |     -     |   0.17   
  12    |   380   |   0.859290   |     -      |     -     |   0.17   
  12    |   400   |   0.801077   |     -      |     -     |   0.17   
  12    |   420   |   0.905157   |     -      |     -     |   0.17   
  12    |   440   | 

  16    |   40    |   0.657229   |     -      |     -     |   0.17   
  16    |   60    |   0.652914   |     -      |     -     |   0.17   
  16    |   80    |   0.611664   |     -      |     -     |   0.18   
  16    |   100   |   0.540488   |     -      |     -     |   0.17   
  16    |   120   |   0.654863   |     -      |     -     |   0.17   
  16    |   140   |   0.609513   |     -      |     -     |   0.18   
  16    |   160   |   0.593880   |     -      |     -     |   0.17   
  16    |   180   |   0.592213   |     -      |     -     |   0.17   
  16    |   200   |   0.583769   |     -      |     -     |   0.17   
  16    |   220   |   0.555560   |     -      |     -     |   0.17   
  16    |   240   |   0.616493   |     -      |     -     |   0.17   
  16    |   260   |   0.670545   |     -      |     -     |   0.17   
  16    |   280   |   0.629149   |     -      |     -     |   0.17   
  16    |   300   |   0.511292   |     -      |     -     |   0.17   
  16    |   320   | 

	the model is improved... save at ./Saved_models/jh_CNN_5M_31D_17H_36M.model


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  20    |   20    |   0.382026   |     -      |     -     |   0.19   
  20    |   40    |   0.410938   |     -      |     -     |   0.17   
  20    |   60    |   0.465795   |     -      |     -     |   0.17   
  20    |   80    |   0.409907   |     -      |     -     |   0.18   
  20    |   100   |   0.429868   |     -      |     -     |   0.17   
  20    |   120   |   0.458326   |     -      |     -     |   0.17   
  20    |   140   |   0.398467   |     -      |     -     |   0.17   
  20    |   160   |   0.378875   |     -      |     -     |   0.18   
  20    |   180   |   0.409187   |     -      |     -     |   0.17   
  20    |   200   |   0.417772   |     -      |     -     |   0.17   
  20    |   220   |   0.411222   |     -      |     -     |   0.17   
  20    | 

  23    |   480   |   0.309110   |     -      |     -     |   0.17   
  23    |   487   |   0.275613   |     -      |     -     |   0.06   
----------------------------------------------------------------------
  23    |    -    |   0.313351   |  0.855304  |   70.91   |   4.35   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  24    |   20    |   0.256495   |     -      |     -     |   0.16   
  24    |   40    |   0.303736   |     -      |     -     |   0.17   
  24    |   60    |   0.310158   |     -      |     -     |   0.17   
  24    |   80    |   0.296585   |     -      |     -     |   0.17   
  24    |   100   |   0.241332   |     -      |     -     |   0.17   
  24    |   120   |   0.334760   |     -      |     -     |   0.17   
  24    |   140   |   0.312198   |     -      |     -     |   0.17   
  24    |   160

  27    |   420   |   0.212194   |     -      |     -     |   0.17   
  27    |   440   |   0.207299   |     -      |     -     |   0.17   
  27    |   460   |   0.254591   |     -      |     -     |   0.17   
  27    |   480   |   0.248497   |     -      |     -     |   0.17   
  27    |   487   |   0.163153   |     -      |     -     |   0.06   
----------------------------------------------------------------------
  27    |    -    |   0.217229   |  0.877459  |   70.85   |   4.31   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  28    |   20    |   0.157429   |     -      |     -     |   0.17   
  28    |   40    |   0.202456   |     -      |     -     |   0.17   
  28    |   60    |   0.209739   |     -      |     -     |   0.17   
  28    |   80    |   0.196520   |     -      |     -     |   0.17   
  28    |   100

  31    |   320   |   0.142053   |     -      |     -     |   0.17   
  31    |   340   |   0.142560   |     -      |     -     |   0.17   
  31    |   360   |   0.085070   |     -      |     -     |   0.17   
  31    |   380   |   0.179448   |     -      |     -     |   0.17   
  31    |   400   |   0.132725   |     -      |     -     |   0.17   
  31    |   420   |   0.143118   |     -      |     -     |   0.17   
  31    |   440   |   0.145695   |     -      |     -     |   0.17   
  31    |   460   |   0.150210   |     -      |     -     |   0.17   
  31    |   480   |   0.152392   |     -      |     -     |   0.18   
  31    |   487   |   0.153602   |     -      |     -     |   0.06   
----------------------------------------------------------------------
  31    |    -    |   0.145736   |  0.922367  |   71.25   |   4.35   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------

  35    |   240   |   0.094794   |     -      |     -     |   0.17   
  35    |   260   |   0.097164   |     -      |     -     |   0.17   
  35    |   280   |   0.114016   |     -      |     -     |   0.17   
  35    |   300   |   0.076654   |     -      |     -     |   0.18   
  35    |   320   |   0.126117   |     -      |     -     |   0.17   
  35    |   340   |   0.123828   |     -      |     -     |   0.17   
  35    |   360   |   0.102859   |     -      |     -     |   0.17   
  35    |   380   |   0.107297   |     -      |     -     |   0.17   
  35    |   400   |   0.090729   |     -      |     -     |   0.17   
  35    |   420   |   0.098251   |     -      |     -     |   0.17   
  35    |   440   |   0.083054   |     -      |     -     |   0.17   
  35    |   460   |   0.101433   |     -      |     -     |   0.17   
  35    |   480   |   0.101663   |     -      |     -     |   0.17   
  35    |   487   |   0.100722   |     -      |     -     |   0.06   
--------------------

  39    |   180   |   0.097794   |     -      |     -     |   0.17   
  39    |   200   |   0.079487   |     -      |     -     |   0.17   
  39    |   220   |   0.070078   |     -      |     -     |   0.17   
  39    |   240   |   0.048695   |     -      |     -     |   0.17   
  39    |   260   |   0.085382   |     -      |     -     |   0.17   
  39    |   280   |   0.078460   |     -      |     -     |   0.17   
  39    |   300   |   0.083368   |     -      |     -     |   0.17   
  39    |   320   |   0.071782   |     -      |     -     |   0.17   
  39    |   340   |   0.078055   |     -      |     -     |   0.17   
  39    |   360   |   0.065001   |     -      |     -     |   0.17   
  39    |   380   |   0.102930   |     -      |     -     |   0.17   
  39    |   400   |   0.077327   |     -      |     -     |   0.17   
  39    |   420   |   0.102957   |     -      |     -     |   0.17   
  39    |   440   |   0.071188   |     -      |     -     |   0.17   
  39    |   460   | 

  43    |   140   |   0.104091   |     -      |     -     |   0.17   
  43    |   160   |   0.057177   |     -      |     -     |   0.17   
  43    |   180   |   0.054403   |     -      |     -     |   0.17   
  43    |   200   |   0.053054   |     -      |     -     |   0.17   
  43    |   220   |   0.066055   |     -      |     -     |   0.17   
  43    |   240   |   0.043956   |     -      |     -     |   0.17   
  43    |   260   |   0.039401   |     -      |     -     |   0.17   
  43    |   280   |   0.066990   |     -      |     -     |   0.17   
  43    |   300   |   0.043728   |     -      |     -     |   0.17   
  43    |   320   |   0.070135   |     -      |     -     |   0.17   
  43    |   340   |   0.084080   |     -      |     -     |   0.18   
  43    |   360   |   0.063331   |     -      |     -     |   0.17   
  43    |   380   |   0.058124   |     -      |     -     |   0.17   
  43    |   400   |   0.035601   |     -      |     -     |   0.17   
  43    |   420   | 

  47    |   60    |   0.083626   |     -      |     -     |   0.17   
  47    |   80    |   0.051492   |     -      |     -     |   0.17   
  47    |   100   |   0.048527   |     -      |     -     |   0.17   
  47    |   120   |   0.063537   |     -      |     -     |   0.17   
  47    |   140   |   0.045873   |     -      |     -     |   0.17   
  47    |   160   |   0.026192   |     -      |     -     |   0.17   
  47    |   180   |   0.025271   |     -      |     -     |   0.17   
  47    |   200   |   0.056327   |     -      |     -     |   0.17   
  47    |   220   |   0.036933   |     -      |     -     |   0.17   
  47    |   240   |   0.037450   |     -      |     -     |   0.17   
  47    |   260   |   0.046590   |     -      |     -     |   0.17   
  47    |   280   |   0.055109   |     -      |     -     |   0.17   
  47    |   300   |   0.048716   |     -      |     -     |   0.17   
  47    |   320   |   0.028076   |     -      |     -     |   0.17   
  47    |   340   | 

In [5]:
full_dataloader = data_load(opt, flag="full")
untrained_model, optimizer, scheduler = initialize_model(opt, len(train_dataloader), device)
full_trained_model, _, _, _ = train_fn(untrained_model, optimizer, scheduler, loss_fn, full_dataloader, evaluation=False)
model_save_path = str(opt.save_model_path) + "/" + opt.signature +'_full.model'
if opt.save == 1: 
    torch.save(full_trained_model.state_dict(), model_save_path)
    print('\tthe model is improved... save at', model_save_path)

Tokenizing data...
Apply the BertTokenizer...


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


full_X_ids_tsr.shape: torch.Size([8563, 50])
full_X_masks_tsr.shape: torch.Size([8563, 50])
Start training...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   1    |   20    |   1.730783   |     -      |     -     |   0.20   
   1    |   40    |   1.606503   |     -      |     -     |   0.19   
   1    |   60    |   1.688982   |     -      |     -     |   0.16   
   1    |   80    |   1.648260   |     -      |     -     |   0.18   
   1    |   100   |   1.621243   |     -      |     -     |   0.19   
   1    |   120   |   1.618031   |     -      |     -     |   0.19   
   1    |   140   |   1.635443   |     -      |     -     |   0.18   
   1    |   160   |   1.614198   |     -      |     -     |   0.19   
   1    |   180   |   1.589714   |     -      |     -     |   0.18   
   1    |   200   |   1.583488   |     -      |     -     |   0.18   
   1    |   220   |   1.597896   |     -      | 

   4    |   520   |   1.403237   |     -      |     -     |   0.19   
   4    |   535   |   1.422029   |     -      |     -     |   0.13   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   5    |   20    |   1.307340   |     -      |     -     |   0.19   
   5    |   40    |   1.377474   |     -      |     -     |   0.18   
   5    |   60    |   1.366840   |     -      |     -     |   0.19   
   5    |   80    |   1.400770   |     -      |     -     |   0.19   
   5    |   100   |   1.359424   |     -      |     -     |   0.18   
   5    |   120   |   1.362783   |     -      |     -     |   0.18   
   5    |   140   |   1.371508   |     -      |     -     |   0.19   
   5    |   160   |   1.339382   |     -      |     -     |   0.18   
   5    |   180   |   1.309885   |     -      |     -     |   0.19   
   5    |   200 

   8    |   460   |   1.081888   |     -      |     -     |   0.19   
   8    |   480   |   1.153872   |     -      |     -     |   0.17   
   8    |   500   |   1.085797   |     -      |     -     |   0.19   
   8    |   520   |   1.150797   |     -      |     -     |   0.20   
   8    |   535   |   1.126930   |     -      |     -     |   0.14   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   9    |   20    |   1.019882   |     -      |     -     |   0.21   
   9    |   40    |   1.085736   |     -      |     -     |   0.19   
   9    |   60    |   1.093574   |     -      |     -     |   0.19   
   9    |   80    |   1.163461   |     -      |     -     |   0.19   
   9    |   100   |   1.110903   |     -      |     -     |   0.17   
   9    |   120   |   1.129966   |     -      |     -     |   0.20   
   9    |   140 

  12    |   420   |   0.820629   |     -      |     -     |   0.19   
  12    |   440   |   0.893056   |     -      |     -     |   0.18   
  12    |   460   |   0.835828   |     -      |     -     |   0.19   
  12    |   480   |   0.862636   |     -      |     -     |   0.19   
  12    |   500   |   0.798658   |     -      |     -     |   0.19   
  12    |   520   |   0.890386   |     -      |     -     |   0.19   
  12    |   535   |   0.938933   |     -      |     -     |   0.14   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  13    |   20    |   0.728163   |     -      |     -     |   0.20   
  13    |   40    |   0.813877   |     -      |     -     |   0.20   
  13    |   60    |   0.817461   |     -      |     -     |   0.17   
  13    |   80    |   0.918868   |     -      |     -     |   0.20   
  13    |   100 

  16    |   380   |   0.689516   |     -      |     -     |   0.18   
  16    |   400   |   0.553778   |     -      |     -     |   0.18   
  16    |   420   |   0.596195   |     -      |     -     |   0.18   
  16    |   440   |   0.638169   |     -      |     -     |   0.18   
  16    |   460   |   0.558780   |     -      |     -     |   0.18   
  16    |   480   |   0.605209   |     -      |     -     |   0.17   
  16    |   500   |   0.544849   |     -      |     -     |   0.19   
  16    |   520   |   0.607784   |     -      |     -     |   0.18   
  16    |   535   |   0.607231   |     -      |     -     |   0.13   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  17    |   20    |   0.515008   |     -      |     -     |   0.18   
  17    |   40    |   0.565310   |     -      |     -     |   0.18   
  17    |   60  

  20    |   340   |   0.383285   |     -      |     -     |   0.18   
  20    |   360   |   0.434991   |     -      |     -     |   0.17   
  20    |   380   |   0.470917   |     -      |     -     |   0.17   
  20    |   400   |   0.362880   |     -      |     -     |   0.17   
  20    |   420   |   0.391040   |     -      |     -     |   0.17   
  20    |   440   |   0.449728   |     -      |     -     |   0.17   
  20    |   460   |   0.375685   |     -      |     -     |   0.18   
  20    |   480   |   0.453518   |     -      |     -     |   0.17   
  20    |   500   |   0.381449   |     -      |     -     |   0.17   
  20    |   520   |   0.434431   |     -      |     -     |   0.17   
  20    |   535   |   0.423220   |     -      |     -     |   0.13   
----------------------------------------------------------------------


 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
  21    |   20  

  24    |   300   |   0.335954   |     -      |     -     |   0.17   
  24    |   320   |   0.271892   |     -      |     -     |   0.17   
  24    |   340   |   0.232569   |     -      |     -     |   0.17   
  24    |   360   |   0.273301   |     -      |     -     |   0.17   
  24    |   380   |   0.306439   |     -      |     -     |   0.18   
  24    |   400   |   0.279423   |     -      |     -     |   0.18   
  24    |   420   |   0.299427   |     -      |     -     |   0.17   
  24    |   440   |   0.333903   |     -      |     -     |   0.17   
  24    |   460   |   0.298291   |     -      |     -     |   0.17   
  24    |   480   |   0.268971   |     -      |     -     |   0.17   
  24    |   500   |   0.242430   |     -      |     -     |   0.17   
  24    |   520   |   0.273117   |     -      |     -     |   0.17   
  24    |   535   |   0.313474   |     -      |     -     |   0.13   
----------------------------------------------------------------------


 Epoch  |  Batch 

  28    |   260   |   0.222145   |     -      |     -     |   0.17   
  28    |   280   |   0.229807   |     -      |     -     |   0.17   
  28    |   300   |   0.226322   |     -      |     -     |   0.17   
  28    |   320   |   0.246254   |     -      |     -     |   0.17   
  28    |   340   |   0.151567   |     -      |     -     |   0.18   
  28    |   360   |   0.180701   |     -      |     -     |   0.17   
  28    |   380   |   0.234953   |     -      |     -     |   0.17   
  28    |   400   |   0.200806   |     -      |     -     |   0.17   
  28    |   420   |   0.186726   |     -      |     -     |   0.17   
  28    |   440   |   0.228516   |     -      |     -     |   0.17   
  28    |   460   |   0.252211   |     -      |     -     |   0.17   
  28    |   480   |   0.186073   |     -      |     -     |   0.18   
  28    |   500   |   0.181911   |     -      |     -     |   0.17   
  28    |   520   |   0.220179   |     -      |     -     |   0.17   
  28    |   535   | 

  32    |   220   |   0.112812   |     -      |     -     |   0.15   
  32    |   240   |   0.206734   |     -      |     -     |   0.17   
  32    |   260   |   0.154727   |     -      |     -     |   0.18   
  32    |   280   |   0.154769   |     -      |     -     |   0.17   
  32    |   300   |   0.183534   |     -      |     -     |   0.17   
  32    |   320   |   0.174243   |     -      |     -     |   0.17   
  32    |   340   |   0.110981   |     -      |     -     |   0.17   
  32    |   360   |   0.134342   |     -      |     -     |   0.17   
  32    |   380   |   0.157411   |     -      |     -     |   0.17   
  32    |   400   |   0.148316   |     -      |     -     |   0.17   
  32    |   420   |   0.122116   |     -      |     -     |   0.17   
  32    |   440   |   0.148917   |     -      |     -     |   0.17   
  32    |   460   |   0.169553   |     -      |     -     |   0.17   
  32    |   480   |   0.116383   |     -      |     -     |   0.17   
  32    |   500   | 

  36    |   180   |   0.110034   |     -      |     -     |   0.18   
  36    |   200   |   0.158611   |     -      |     -     |   0.17   
  36    |   220   |   0.096950   |     -      |     -     |   0.17   
  36    |   240   |   0.117356   |     -      |     -     |   0.17   
  36    |   260   |   0.117811   |     -      |     -     |   0.17   
  36    |   280   |   0.103549   |     -      |     -     |   0.17   
  36    |   300   |   0.122053   |     -      |     -     |   0.17   
  36    |   320   |   0.120273   |     -      |     -     |   0.17   
  36    |   340   |   0.098863   |     -      |     -     |   0.17   
  36    |   360   |   0.121656   |     -      |     -     |   0.17   
  36    |   380   |   0.088447   |     -      |     -     |   0.17   
  36    |   400   |   0.115967   |     -      |     -     |   0.17   
  36    |   420   |   0.127668   |     -      |     -     |   0.17   
  36    |   440   |   0.116400   |     -      |     -     |   0.17   
  36    |   460   | 

  40    |   140   |   0.075891   |     -      |     -     |   0.17   
  40    |   160   |   0.066403   |     -      |     -     |   0.17   
  40    |   180   |   0.064310   |     -      |     -     |   0.18   
  40    |   200   |   0.118841   |     -      |     -     |   0.17   
  40    |   220   |   0.092962   |     -      |     -     |   0.17   
  40    |   240   |   0.091247   |     -      |     -     |   0.17   
  40    |   260   |   0.121225   |     -      |     -     |   0.17   
  40    |   280   |   0.096044   |     -      |     -     |   0.17   
  40    |   300   |   0.113063   |     -      |     -     |   0.17   
  40    |   320   |   0.077878   |     -      |     -     |   0.17   
  40    |   340   |   0.072154   |     -      |     -     |   0.17   
  40    |   360   |   0.075660   |     -      |     -     |   0.18   
  40    |   380   |   0.097486   |     -      |     -     |   0.17   
  40    |   400   |   0.078828   |     -      |     -     |   0.17   
  40    |   420   | 

  44    |   100   |   0.110475   |     -      |     -     |   0.17   
  44    |   120   |   0.074325   |     -      |     -     |   0.17   
  44    |   140   |   0.085229   |     -      |     -     |   0.18   
  44    |   160   |   0.078210   |     -      |     -     |   0.17   
  44    |   180   |   0.050383   |     -      |     -     |   0.17   
  44    |   200   |   0.076322   |     -      |     -     |   0.17   
  44    |   220   |   0.067966   |     -      |     -     |   0.17   
  44    |   240   |   0.053344   |     -      |     -     |   0.17   
  44    |   260   |   0.078381   |     -      |     -     |   0.17   
  44    |   280   |   0.061105   |     -      |     -     |   0.17   
  44    |   300   |   0.065296   |     -      |     -     |   0.17   
  44    |   320   |   0.079139   |     -      |     -     |   0.17   
  44    |   340   |   0.057201   |     -      |     -     |   0.17   
  44    |   360   |   0.080726   |     -      |     -     |   0.17   
  44    |   380   | 

  48    |   60    |   0.076014   |     -      |     -     |   0.18   
  48    |   80    |   0.055788   |     -      |     -     |   0.17   
  48    |   100   |   0.083221   |     -      |     -     |   0.17   
  48    |   120   |   0.054099   |     -      |     -     |   0.17   
  48    |   140   |   0.041121   |     -      |     -     |   0.17   
  48    |   160   |   0.041242   |     -      |     -     |   0.17   
  48    |   180   |   0.040412   |     -      |     -     |   0.18   
  48    |   200   |   0.086838   |     -      |     -     |   0.17   
  48    |   220   |   0.059357   |     -      |     -     |   0.17   
  48    |   240   |   0.041161   |     -      |     -     |   0.17   
  48    |   260   |   0.067457   |     -      |     -     |   0.17   
  48    |   280   |   0.069516   |     -      |     -     |   0.17   
  48    |   300   |   0.044191   |     -      |     -     |   0.17   
  48    |   320   |   0.035217   |     -      |     -     |   0.17   
  48    |   340   | 

In [6]:
# Save the submission file
# make_submission(trained_model, opt, device, test_dataloader)
# make_submission(trained_model, opt, device, test_dataloader, full=True)

In [7]:
# Print the number of parameters
numOfparams(trained_model)

9358637