In [None]:
# Desired Imports
import torch
import tqdm
from tqdm import trange
from transformers import (AdamW, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup)
from torch.utils.data import Dataset
import pickle
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import os
import shutil
import subprocess
import json
import torch.nn.utils as F
from transformers import WEIGHTS_NAME
import glob
import pandas as pd
import torch.nn as nn

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Directories needed
qna_dataset_dir = "/content/drive/MyDrive/QnA"
qna_model_chkpts_dir =  "/content/drive/MyDrive/QnA/"
final_qna_model_dir = "/content/drive/MyDrive/QnA/final_QnA"

In [None]:
# Required Arguments
args_dir = {
  "save_steps" : 20, # can be changed
  "num_epochs" : 3,
  "gradient_accumulation_steps" : 10,
  "adam_epsilon" : 1e-8,
  "warmup_steps" : 0,
  "learning_rt" : 5e-5,
  "max_grad_norm" : 1.0,
  "data_dir" : qna_dataset_dir,
  "model_type" : "gpt2",
  "model_name" : "gpt2",  # set to gtp2-large
  "train_batch_size" : 5,
  "eval_batch_size" : 5,
  "extra_embedding_dim" : 768,
  "global_dense_feature_list" : None # in file it will be saved with the value null; while reading take care of this thing
}

model_type = args_dir["model_type"]
model_name = args_dir["model_name"]
data_dir = args_dir["data_dir"]
save_steps = args_dir["save_steps"]
num_epochs = args_dir["num_epochs"]
gradient_accumulation_steps = args_dir["gradient_accumulation_steps"]
adam_epsilon = args_dir["adam_epsilon"]
warmup_steps = args_dir["warmup_steps"]
train_batch_size = args_dir["train_batch_size"]
eval_batch_size = args_dir["eval_batch_size"]
learning_rt = args_dir["learning_rt"]
extra_embedding_dim = args_dir["extra_embedding_dim"] # Size of linear layer used for projecting extra embeddings.
global_dense_feature_list = args_dir["global_dense_feature_list"]
max_grad_norm = args_dir["max_grad_norm"]

In [None]:
# Choose device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count()

print("Device- ", device)
print("No. of GPUs- ", n_gpus)

Device-  cuda
No. of GPUs-  1


In [None]:
# Install Transformers
!pip install transformers



In [None]:
# Initialize model classes variables
MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
}
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

print("GPT2 Config class- ", config_class)
print("GPT2 Model class- ", model_class)
print("GPT2 Tokenizer class- ", tokenizer_class)

GPT2 Config class-  <class 'transformers.models.gpt2.configuration_gpt2.GPT2Config'>
GPT2 Model class-  <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
GPT2 Tokenizer class-  <class 'transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer'>


In [None]:
# Init config
config = config_class.from_pretrained(model_name)
print("GPT2Config loaded")

# Init model
model = model_class.from_pretrained(model_name, config = config)
print("GPT2LMHeadModel loaded")

# Init tokenizer
tokenizer = tokenizer_class.from_pretrained(model_name,do_lower_case = False)
print("GPT2Tokenizer loaded")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

GPT2Config loaded


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

GPT2LMHeadModel loaded


vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

GPT2Tokenizer loaded


In [None]:
# adding extra_dimension to config --- No need in DVP ---- Can be removed later
config.extra_embedding_dim = extra_embedding_dim # don't know why we are using it; -- explore it

In [None]:
# Add special tokens to tokenizer
SPECIAL_TOKENS = {
    "additional_special_tokens": ["<segment_1>", "<segment_2>"],
    "pad_token": "<pad>",
    "bos_token": "<bos>",
    "eos_token": "<eos>"
}
tokenizer.add_special_tokens(SPECIAL_TOKENS)
print("Special Tokens addded to tokenizer")

print("Total tokens- ", len(tokenizer))

Special Tokens addded to tokenizer
Total tokens-  50262


In [None]:
# resize token embedding matrix to take care of special tokens added
model.resize_token_embeddings(len(tokenizer)) # each token of size-> 1280(gpt2-large), 768(gpt2)

Embedding(50262, 768)

In [None]:
# move model to device
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50262, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50262, bias=False)
)

In [None]:
# Init configs

MAX_PARAPHRASE_LEN = 100

# mainly to handle input
INPUT_FORMAT_CONFIG = {
    "keys": [
        {"key": "sent1_tokens", "position": 0},
        {"key": "sent2_tokens", "position": 1}
    ],
    "max_prefix_length": int(MAX_PARAPHRASE_LEN / 2),
    "max_suffix_length": int(MAX_PARAPHRASE_LEN / 2)
}

In [None]:
# Fn to convert example input to dictionary
def input_to_dict(config, sample, tokenizer):
    example = {}

    for inp_key in config["keys"]:
        val = sample[inp_key["position"]]
        example[inp_key["key"]] = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(val))

    return example

In [None]:
# Preprocess input from paranmt
def preprocess(exp, tokenizer, config, do_tokenize=True):
  MASK_TOKEN_ID = -100

  max_prefix_len = config["max_prefix_length"]
  max_suffix_len = config["max_suffix_length"]

  if do_tokenize:
    sent1 = np.array(exp["sent1_tokens"])
    sent2 = np.array(exp["sent2_tokens"])

  # truncate
  if(len(sent1) > max_prefix_len):
    sent1 = sent1[:max_prefix_len]

  if(len(sent2) > max_suffix_len):
    sent2 = sent2[:max_suffix_len]

  # add padding; left padding to prefix and right padding to suffix
  count_pad_tokens_prefix = max_prefix_len - len(sent1)
  sent1 = np.pad(sent1, (count_pad_tokens_prefix, 0), constant_values = tokenizer.pad_token_id)

  # add <eos> to suffix
  sent2 = np.append(sent2, tokenizer.eos_token_id)

  count_pad_tokens_suffix = (max_suffix_len + 1) - len(sent2)
  sent2 = np.pad(sent2, (0, count_pad_tokens_suffix), constant_values = tokenizer.pad_token_id)

  # sentence to input gpt2
  sentence_to_input_gpt2 = np.concatenate([sent1, [tokenizer.bos_token_id], sent2]).astype(np.int64) # [sent1, <bos> sent2]

  # label/gt to predict; -100 used for masking that input (in ground truth only)
  gt = np.concatenate([
      [MASK_TOKEN_ID for _ in sent1],
      [MASK_TOKEN_ID],
      [val if val != tokenizer.pad_token_id else MASK_TOKEN_ID for val in sent2]
  ]).astype(np.int64)

  # segment
  segment = np.concatenate([
      [tokenizer.additional_special_tokens_ids[0] for _ in sent1],
      [tokenizer.additional_special_tokens_ids[1]],
      [tokenizer.additional_special_tokens_ids[1] for _ in sent2]
  ]).astype(np.int64)

  exp["prefix_sent"] = sent1
  exp["suffix_sent"] = sent2

  exp["input"] = sentence_to_input_gpt2
  exp["label"] = gt
  exp["segment"] = segment

  return exp

In [None]:
# read csv file
df = pd.read_csv('/content/drive/MyDrive/QnA/train_orig.csv')
df.head()

q = list(df["Question"])
a = list(df["Answer"])

mix_question_answer = [(q[i], a[i]) for i in range(len(q))]

In [None]:
mix_question_answer_np = np.array(mix_question_answer)
np.random.shuffle(mix_question_answer_np)

In [None]:
type(mix_question_answer_np)

numpy.ndarray

In [None]:
mix_question_answer_np = mix_question_answer_np.tolist()

In [None]:
# t = (1,2)
# print(t[0])

1


In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# split into train, validation and test
all_q = [mix_question_answer_np[i][0] for i in range(len(mix_question_answer_np))]
all_a = [mix_question_answer_np[i][1] for i in range(len(mix_question_answer_np))]


train_q, test_q, train_ans, test_ans = train_test_split(all_q, all_a, test_size=0.33, random_state=42)
train_q, val_q, train_ans, val_ans = train_test_split(train_q, train_ans, test_size=0.33, random_state=42)

In [None]:
val_q, test_q = test_q, val_q
val_ans, test_ans = test_ans, val_ans

In [None]:
print(len(train_q), len(test_q), len(val_q), sep="\n")
print(len(train_ans), len(test_ans), len(val_ans), sep="\n")

7364
3628
5415
7364
3628
5415


In [None]:
# QNA dataset
class QnA_Dataset(Dataset):
    def __init__(self, qna_dataset_dir, config, tokenizer, ques, ans, limit_examples = None, evaluate = False, split_type = "train"):
      self.config = config
      self.examples = []

      split_data = [(ques[i], ans[i]) for i in range(len(ques))]
      self.examples = [input_to_dict(self.config, sample, tokenizer) for sample in tqdm.tqdm(split_data)]

      print("\n\n After conversion- ", self.examples[0])

      # Reduce dataset if required
      if limit_examples != None:
        self.examples = self.examples[:limit_examples]

      print("\n\n Doing Preprocess each sample")
      # do Preprocessing in each of the converted samples
      self.examples = [preprocess(exp, tokenizer, self.config, do_tokenize = True) for exp in self.examples]

      print("\n\n After preprocessing- ", self.examples[0])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
      sentence = self.examples[idx]["input"]
      label = self.examples[idx]["label"]
      segment = self.examples[idx]["segment"]
      context_len = self.config["max_prefix_length"] + 1 # (+1) for <bos>

      return {
          "sample_number": idx,
          "sentence": torch.tensor(sentence),
          "label": torch.tensor(label),
          "segment": torch.tensor(segment)
      }

In [None]:
# create DVP dataset
train_dataset = QnA_Dataset(qna_dataset_dir,
                                            INPUT_FORMAT_CONFIG,
                                            tokenizer,train_q, train_ans,
                                            limit_examples = None,
                                            evaluate = False, split_type = "train")
print("\n\n QnA Dataset created")

100%|██████████| 7364/7364 [00:31<00:00, 231.04it/s]




 After conversion-  {'sent1_tokens': [2061, 5640, 33636, 11880, 5633], 'sent2_tokens': [2061, 5640, 22987, 11880, 30, 1148, 340, 8513, 30, 383, 2748, 2728, 286, 22987, 11880, 468, 407, 587, 5174, 13, 2102, 11, 780, 262, 4006, 743, 307, 1944, 287, 1811, 1866, 286, 262, 976, 1641, 11, 25862, 743, 2620, 257, 1048, 338, 8395, 286, 5922, 262, 4006, 13, 317, 2050, 416, 1962, 320, 3565, 357, 12726, 8, 3751, 326, 257, 2176, 15304, 286, 257, 9779, 1444, 14639, 12, 16, 33, 357, 3849, 293, 2724, 259, 12, 16, 12159, 8, 318, 3917, 351, 281, 3220, 2526, 286, 5922, 22987, 11880, 290, 5644, 257, 8513, 4308, 329, 262, 2478, 286, 262, 4369, 13, 7735, 2267, 743, 1255, 287, 257, 1365, 4547, 286, 262, 8513, 16717, 2950, 287, 262, 2478, 286, 22987, 11880, 13]}


 Doing Preprocess each sample


 After preprocessing-  {'sent1_tokens': [2061, 5640, 33636, 11880, 5633], 'sent2_tokens': [2061, 5640, 22987, 11880, 30, 1148, 340, 8513, 30, 383, 2748, 2728, 286, 22987, 11880, 468, 407, 587, 5174, 13, 2102, 11, 78

In [None]:
# create dataloader
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler = train_sampler, batch_size = train_batch_size)

print("QnA train dataloader created")

QnA train dataloader created


In [None]:
# Total steps needed
t_total = len(train_dataloader) // gradient_accumulation_steps * num_epochs

# setting up the optimizer & learning rate schedulers
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
grouped_parameters = [
    {
        'params': [p for n, p in model.named_parameters()],
        'weight_decay': 0.0
    }
]

optimizer = AdamW(grouped_parameters, lr = float(learning_rt), eps = adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warmup_steps, num_training_steps = t_total)

print("Adam Optimizer and learning rate scheduler instantiated")



Adam Optimizer and learning rate scheduler instantiated


In [None]:
# Training meta Information
print("Num of examples- ", len(train_dataset))
print("Num of epochs- ", num_epochs)
print("Batch size- ", train_batch_size)
print("Gradient acculmulation steps- ", gradient_accumulation_steps)
print("Total optimization steps- ", t_total)

Num of examples-  7364
Num of epochs-  3
Batch size-  5
Gradient acculmulation steps-  10
Total optimization steps-  441


In [None]:
# zero out all the gradients
model.zero_grad()

In [None]:
# Fn to save checkpoints
def save_model(model, tokenizer, chkpt_dir, global_step, args_dir):
  if not os.path.exists(chkpt_dir):
    os.makedirs(chkpt_dir)
  # print("Directory created for new checkpt to save")

  model.save_pretrained(chkpt_dir)
  tokenizer.save_pretrained(chkpt_dir)
  # print("Model and tokenizer saved")

  # save training arguments also
  with open(chkpt_dir + "/my_args.json", "w") as json_file:
    json.dump(args_dir, json_file)
  # print("Training arguments saved")

  with open(os.path.join(chkpt_dir, "global_step.txt"), "w") as f:
    f.write(str(global_step) + "\n")
  # print("Global step file saved")

  print("Checkpint saving process done..")

In [None]:
from tqdm import tqdm

global_step = 0
train_loss_val = 0.0
chkpts_dir_name = []

# start training
train_iterator = trange(int(num_epochs), desc = "Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(train_dataloader)

    for batch_idx, batch in enumerate(epoch_iterator):
      sentences = batch["sentence"].to(device)
      labels = batch["label"].to(device)
      segments = batch["segment"].to(device)
      model.train()

      outputs = model(input_ids=sentences, token_type_ids=segments, labels=labels)
      # print("Got logits and loss")

      loss = outputs.loss
      loss = loss / gradient_accumulation_steps
      train_loss_val += loss.item()

      loss.backward()

      if (((batch_idx + 1) % gradient_accumulation_steps) == 0):
        # print("Moved 1 step")
        F.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        model.zero_grad()
        global_step += 1

        if (global_step % save_steps == 0):
          # save checkpoint here
          print("Saving new checkpoint")
          chkpt_dir = qna_model_chkpts_dir + "/qna_chkpt_"+str(global_step)
          chkpts_dir_name.append("qna_chkpt_"+str(global_step))

          save_model(model, tokenizer, chkpt_dir, global_step, args_dir)
      # break

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]
  0%|          | 0/1473 [00:00<?, ?it/s][A
  0%|          | 1/1473 [00:02<59:43,  2.43s/it][A
  0%|          | 2/1473 [00:02<27:16,  1.11s/it][A
  0%|          | 3/1473 [00:02<16:20,  1.50it/s][A
  0%|          | 4/1473 [00:02<11:35,  2.11it/s][A
  0%|          | 5/1473 [00:03<08:38,  2.83it/s][A
  0%|          | 6/1473 [00:03<06:50,  3.57it/s][A
  0%|          | 7/1473 [00:03<05:50,  4.19it/s][A
  1%|          | 8/1473 [00:03<05:35,  4.36it/s][A
  1%|          | 9/1473 [00:03<04:42,  5.19it/s][A
  1%|          | 10/1473 [00:04<06:51,  3.56it/s][A
  1%|          | 11/1473 [00:04<05:31,  4.41it/s][A
  1%|          | 12/1473 [00:04<04:48,  5.06it/s][A
  1%|          | 13/1473 [00:04<04:23,  5.54it/s][A
  1%|          | 14/1473 [00:04<04:02,  6.02it/s][A
  1%|          | 15/1473 [00:04<03:51,  6.31it/s][A
  1%|          | 16/1473 [00:04<03:37,  6.71it/s][A
  1%|          | 17/1473 [00:05<03:31,  6.89it/s][A
  1%|          | 18

Saving new checkpoint



 14%|█▎        | 200/1473 [00:33<16:41,  1.27it/s][A

Checkpint saving process done..



 14%|█▎        | 202/1473 [00:33<10:23,  2.04it/s][A
 14%|█▍        | 203/1473 [00:34<08:35,  2.46it/s][A
 14%|█▍        | 204/1473 [00:34<07:04,  2.99it/s][A
 14%|█▍        | 205/1473 [00:34<05:54,  3.57it/s][A
 14%|█▍        | 206/1473 [00:34<05:02,  4.18it/s][A
 14%|█▍        | 207/1473 [00:34<04:24,  4.78it/s][A
 14%|█▍        | 208/1473 [00:34<03:56,  5.34it/s][A
 14%|█▍        | 209/1473 [00:34<03:37,  5.81it/s][A
 14%|█▍        | 210/1473 [00:35<04:03,  5.18it/s][A
 14%|█▍        | 212/1473 [00:35<03:13,  6.53it/s][A
 14%|█▍        | 213/1473 [00:35<03:06,  6.77it/s][A
 15%|█▍        | 214/1473 [00:35<03:02,  6.89it/s][A
 15%|█▍        | 215/1473 [00:35<02:58,  7.04it/s][A
 15%|█▍        | 216/1473 [00:35<02:58,  7.06it/s][A
 15%|█▍        | 217/1473 [00:36<02:56,  7.11it/s][A
 15%|█▍        | 218/1473 [00:36<02:57,  7.08it/s][A
 15%|█▍        | 219/1473 [00:36<02:55,  7.16it/s][A
 15%|█▍        | 220/1473 [00:36<03:27,  6.05it/s][A
 15%|█▌        | 222/1473 [

Saving new checkpoint



 27%|██▋       | 400/1473 [01:05<18:30,  1.03s/it][A

Checkpint saving process done..



 27%|██▋       | 402/1473 [01:05<11:06,  1.61it/s][A
 27%|██▋       | 403/1473 [01:05<09:01,  1.98it/s][A
 27%|██▋       | 404/1473 [01:05<07:18,  2.44it/s][A
 27%|██▋       | 405/1473 [01:06<05:59,  2.97it/s][A
 28%|██▊       | 406/1473 [01:06<04:59,  3.56it/s][A
 28%|██▊       | 407/1473 [01:06<04:15,  4.17it/s][A
 28%|██▊       | 408/1473 [01:06<03:43,  4.76it/s][A
 28%|██▊       | 409/1473 [01:06<03:22,  5.26it/s][A
 28%|██▊       | 410/1473 [01:06<03:36,  4.90it/s][A
 28%|██▊       | 412/1473 [01:07<02:49,  6.24it/s][A
 28%|██▊       | 413/1473 [01:07<02:42,  6.50it/s][A
 28%|██▊       | 414/1473 [01:07<02:38,  6.68it/s][A
 28%|██▊       | 415/1473 [01:07<02:34,  6.85it/s][A
 28%|██▊       | 416/1473 [01:07<02:33,  6.90it/s][A
 28%|██▊       | 417/1473 [01:07<02:32,  6.94it/s][A
 28%|██▊       | 418/1473 [01:07<02:30,  7.02it/s][A
 28%|██▊       | 419/1473 [01:08<02:30,  7.01it/s][A
 29%|██▊       | 420/1473 [01:08<03:01,  5.81it/s][A
 29%|██▊       | 422/1473 [

Saving new checkpoint



 41%|████      | 600/1473 [01:36<11:41,  1.24it/s][A

Checkpint saving process done..



 41%|████      | 602/1473 [01:37<07:14,  2.01it/s][A
 41%|████      | 603/1473 [01:37<05:59,  2.42it/s][A
 41%|████      | 604/1473 [01:37<04:57,  2.92it/s][A
 41%|████      | 605/1473 [01:37<04:09,  3.48it/s][A
 41%|████      | 606/1473 [01:37<03:33,  4.06it/s][A
 41%|████      | 607/1473 [01:37<03:06,  4.63it/s][A
 41%|████▏     | 608/1473 [01:38<02:49,  5.11it/s][A
 41%|████▏     | 609/1473 [01:38<02:37,  5.48it/s][A
 41%|████▏     | 610/1473 [01:38<02:53,  4.96it/s][A
 42%|████▏     | 612/1473 [01:38<02:17,  6.28it/s][A
 42%|████▏     | 613/1473 [01:38<02:12,  6.47it/s][A
 42%|████▏     | 614/1473 [01:38<02:09,  6.62it/s][A
 42%|████▏     | 615/1473 [01:39<02:08,  6.68it/s][A
 42%|████▏     | 616/1473 [01:39<02:07,  6.70it/s][A
 42%|████▏     | 617/1473 [01:39<02:06,  6.74it/s][A
 42%|████▏     | 618/1473 [01:39<02:06,  6.74it/s][A
 42%|████▏     | 619/1473 [01:39<02:04,  6.83it/s][A
 42%|████▏     | 620/1473 [01:39<02:28,  5.76it/s][A
 42%|████▏     | 622/1473 [

Saving new checkpoint



 54%|█████▍    | 800/1473 [02:09<08:49,  1.27it/s][A

Checkpint saving process done..



 54%|█████▍    | 802/1473 [02:10<05:29,  2.04it/s][A
 55%|█████▍    | 803/1473 [02:10<04:32,  2.46it/s][A
 55%|█████▍    | 804/1473 [02:10<03:46,  2.95it/s][A
 55%|█████▍    | 805/1473 [02:10<03:08,  3.54it/s][A
 55%|█████▍    | 806/1473 [02:10<02:45,  4.02it/s][A
 55%|█████▍    | 807/1473 [02:10<02:20,  4.74it/s][A
 55%|█████▍    | 808/1473 [02:11<02:07,  5.23it/s][A
 55%|█████▍    | 809/1473 [02:11<01:59,  5.56it/s][A
 55%|█████▍    | 810/1473 [02:11<02:11,  5.04it/s][A
 55%|█████▌    | 812/1473 [02:11<01:43,  6.38it/s][A
 55%|█████▌    | 813/1473 [02:11<01:41,  6.53it/s][A
 55%|█████▌    | 814/1473 [02:11<01:38,  6.69it/s][A
 55%|█████▌    | 815/1473 [02:12<01:37,  6.78it/s][A
 55%|█████▌    | 816/1473 [02:12<01:37,  6.75it/s][A
 55%|█████▌    | 817/1473 [02:12<01:37,  6.76it/s][A
 56%|█████▌    | 818/1473 [02:12<01:37,  6.71it/s][A
 56%|█████▌    | 819/1473 [02:12<01:34,  6.89it/s][A
 56%|█████▌    | 820/1473 [02:12<01:51,  5.86it/s][A
 56%|█████▌    | 822/1473 [

Saving new checkpoint



 68%|██████▊   | 1000/1473 [02:41<07:04,  1.12it/s][A

Checkpint saving process done..



 68%|██████▊   | 1002/1473 [02:42<04:17,  1.83it/s][A
 68%|██████▊   | 1003/1473 [02:42<03:29,  2.25it/s][A
 68%|██████▊   | 1004/1473 [02:42<02:51,  2.73it/s][A
 68%|██████▊   | 1005/1473 [02:42<02:23,  3.27it/s][A
 68%|██████▊   | 1006/1473 [02:42<02:00,  3.89it/s][A
 68%|██████▊   | 1007/1473 [02:42<01:44,  4.47it/s][A
 68%|██████▊   | 1008/1473 [02:42<01:34,  4.94it/s][A
 68%|██████▊   | 1009/1473 [02:43<01:27,  5.29it/s][A
 69%|██████▊   | 1010/1473 [02:43<01:34,  4.91it/s][A
 69%|██████▊   | 1012/1473 [02:43<01:16,  6.06it/s][A
 69%|██████▉   | 1013/1473 [02:43<01:12,  6.39it/s][A
 69%|██████▉   | 1014/1473 [02:43<01:10,  6.55it/s][A
 69%|██████▉   | 1015/1473 [02:43<01:09,  6.55it/s][A
 69%|██████▉   | 1016/1473 [02:44<01:08,  6.69it/s][A
 69%|██████▉   | 1017/1473 [02:44<01:07,  6.75it/s][A
 69%|██████▉   | 1018/1473 [02:44<01:07,  6.73it/s][A
 69%|██████▉   | 1019/1473 [02:44<01:06,  6.79it/s][A
 69%|██████▉   | 1020/1473 [02:44<01:17,  5.88it/s][A
 69%|████

Saving new checkpoint



 81%|████████▏ | 1200/1473 [03:14<04:49,  1.06s/it][A
 82%|████████▏ | 1201/1473 [03:14<03:34,  1.27it/s][A

Checkpint saving process done..



 82%|████████▏ | 1202/1473 [03:14<02:42,  1.66it/s][A
 82%|████████▏ | 1203/1473 [03:14<02:06,  2.14it/s][A
 82%|████████▏ | 1204/1473 [03:14<01:40,  2.68it/s][A
 82%|████████▏ | 1205/1473 [03:15<01:22,  3.25it/s][A
 82%|████████▏ | 1206/1473 [03:15<01:08,  3.93it/s][A
 82%|████████▏ | 1207/1473 [03:15<01:00,  4.38it/s][A
 82%|████████▏ | 1208/1473 [03:15<00:56,  4.71it/s][A
 82%|████████▏ | 1209/1473 [03:15<00:49,  5.33it/s][A
 82%|████████▏ | 1210/1473 [03:15<00:53,  4.92it/s][A
 82%|████████▏ | 1211/1473 [03:16<00:45,  5.73it/s][A
 82%|████████▏ | 1212/1473 [03:16<00:42,  6.14it/s][A
 82%|████████▏ | 1213/1473 [03:16<00:41,  6.31it/s][A
 82%|████████▏ | 1214/1473 [03:16<00:40,  6.35it/s][A
 82%|████████▏ | 1215/1473 [03:16<00:39,  6.53it/s][A
 83%|████████▎ | 1216/1473 [03:16<00:39,  6.43it/s][A
 83%|████████▎ | 1217/1473 [03:16<00:39,  6.47it/s][A
 83%|████████▎ | 1218/1473 [03:17<00:38,  6.59it/s][A
 83%|████████▎ | 1219/1473 [03:17<00:40,  6.27it/s][A
 83%|████

Saving new checkpoint



 95%|█████████▌| 1400/1473 [03:47<01:14,  1.02s/it][A
 95%|█████████▌| 1401/1473 [03:47<00:54,  1.33it/s][A

Checkpint saving process done..



 95%|█████████▌| 1402/1473 [03:47<00:41,  1.72it/s][A
 95%|█████████▌| 1403/1473 [03:47<00:31,  2.23it/s][A
 95%|█████████▌| 1404/1473 [03:47<00:24,  2.81it/s][A
 95%|█████████▌| 1405/1473 [03:47<00:19,  3.42it/s][A
 95%|█████████▌| 1406/1473 [03:47<00:16,  4.06it/s][A
 96%|█████████▌| 1407/1473 [03:48<00:14,  4.66it/s][A
 96%|█████████▌| 1408/1473 [03:48<00:12,  5.13it/s][A
 96%|█████████▌| 1409/1473 [03:48<00:11,  5.45it/s][A
 96%|█████████▌| 1410/1473 [03:48<00:12,  5.00it/s][A
 96%|█████████▌| 1412/1473 [03:48<00:09,  6.29it/s][A
 96%|█████████▌| 1413/1473 [03:48<00:09,  6.46it/s][A
 96%|█████████▌| 1414/1473 [03:49<00:08,  6.56it/s][A
 96%|█████████▌| 1415/1473 [03:49<00:08,  6.57it/s][A
 96%|█████████▌| 1416/1473 [03:49<00:08,  6.64it/s][A
 96%|█████████▌| 1417/1473 [03:49<00:08,  6.70it/s][A
 96%|█████████▋| 1418/1473 [03:49<00:08,  6.78it/s][A
 96%|█████████▋| 1419/1473 [03:49<00:07,  6.81it/s][A
 96%|█████████▋| 1420/1473 [03:50<00:09,  5.81it/s][A
 97%|████

Saving new checkpoint



  9%|▉         | 130/1473 [00:21<21:06,  1.06it/s][A

Checkpint saving process done..



  9%|▉         | 132/1473 [00:22<12:54,  1.73it/s][A
  9%|▉         | 133/1473 [00:22<10:26,  2.14it/s][A
  9%|▉         | 134/1473 [00:22<08:31,  2.62it/s][A
  9%|▉         | 135/1473 [00:22<07:03,  3.16it/s][A
  9%|▉         | 136/1473 [00:22<05:58,  3.73it/s][A
  9%|▉         | 137/1473 [00:22<05:09,  4.32it/s][A
  9%|▉         | 138/1473 [00:22<04:36,  4.83it/s][A
  9%|▉         | 139/1473 [00:23<04:16,  5.20it/s][A
 10%|▉         | 140/1473 [00:23<04:37,  4.81it/s][A
 10%|▉         | 142/1473 [00:23<03:35,  6.16it/s][A
 10%|▉         | 143/1473 [00:23<03:29,  6.34it/s][A
 10%|▉         | 144/1473 [00:23<03:21,  6.58it/s][A
 10%|▉         | 145/1473 [00:23<03:19,  6.66it/s][A
 10%|▉         | 146/1473 [00:24<03:18,  6.70it/s][A
 10%|▉         | 147/1473 [00:24<03:17,  6.72it/s][A
 10%|█         | 148/1473 [00:24<03:14,  6.82it/s][A
 10%|█         | 149/1473 [00:24<03:11,  6.92it/s][A
 10%|█         | 150/1473 [00:24<03:49,  5.77it/s][A
 10%|█         | 152/1473 [

Saving new checkpoint



 22%|██▏       | 330/1473 [00:54<19:38,  1.03s/it][A

Checkpint saving process done..



 23%|██▎       | 332/1473 [00:54<11:51,  1.60it/s][A
 23%|██▎       | 333/1473 [00:54<09:35,  1.98it/s][A
 23%|██▎       | 334/1473 [00:54<07:47,  2.44it/s][A
 23%|██▎       | 335/1473 [00:55<06:22,  2.97it/s][A
 23%|██▎       | 336/1473 [00:55<05:19,  3.56it/s][A
 23%|██▎       | 337/1473 [00:55<04:34,  4.14it/s][A
 23%|██▎       | 338/1473 [00:55<04:00,  4.71it/s][A
 23%|██▎       | 339/1473 [00:55<03:42,  5.10it/s][A
 23%|██▎       | 340/1473 [00:55<03:59,  4.73it/s][A
 23%|██▎       | 342/1473 [00:56<03:06,  6.07it/s][A
 23%|██▎       | 343/1473 [00:56<02:59,  6.29it/s][A
 23%|██▎       | 344/1473 [00:56<02:54,  6.45it/s][A
 23%|██▎       | 345/1473 [00:56<02:50,  6.61it/s][A
 23%|██▎       | 346/1473 [00:56<02:50,  6.60it/s][A
 24%|██▎       | 347/1473 [00:56<02:49,  6.64it/s][A
 24%|██▎       | 348/1473 [00:56<02:44,  6.83it/s][A
 24%|██▎       | 349/1473 [00:57<02:41,  6.96it/s][A
 24%|██▍       | 350/1473 [00:57<03:13,  5.81it/s][A
 24%|██▍       | 352/1473 [

Saving new checkpoint



 36%|███▌      | 530/1473 [01:26<13:22,  1.18it/s][A

Checkpint saving process done..



 36%|███▌      | 532/1473 [01:26<08:08,  1.93it/s][A
 36%|███▌      | 533/1473 [01:26<06:39,  2.35it/s][A
 36%|███▋      | 534/1473 [01:26<05:29,  2.85it/s][A
 36%|███▋      | 535/1473 [01:26<04:37,  3.38it/s][A
 36%|███▋      | 536/1473 [01:27<03:55,  3.98it/s][A
 36%|███▋      | 537/1473 [01:27<03:25,  4.55it/s][A
 37%|███▋      | 538/1473 [01:27<03:06,  5.01it/s][A
 37%|███▋      | 539/1473 [01:27<02:53,  5.40it/s][A
 37%|███▋      | 540/1473 [01:27<03:09,  4.93it/s][A
 37%|███▋      | 542/1473 [01:27<02:29,  6.22it/s][A
 37%|███▋      | 543/1473 [01:28<02:24,  6.42it/s][A
 37%|███▋      | 544/1473 [01:28<02:21,  6.58it/s][A
 37%|███▋      | 545/1473 [01:28<02:20,  6.60it/s][A
 37%|███▋      | 546/1473 [01:28<02:19,  6.65it/s][A
 37%|███▋      | 547/1473 [01:28<02:17,  6.73it/s][A
 37%|███▋      | 548/1473 [01:28<02:14,  6.89it/s][A
 37%|███▋      | 549/1473 [01:28<02:16,  6.79it/s][A
 37%|███▋      | 550/1473 [01:29<02:37,  5.85it/s][A
 37%|███▋      | 552/1473 [

Saving new checkpoint



 50%|████▉     | 730/1473 [01:58<10:08,  1.22it/s][A

Checkpint saving process done..



 50%|████▉     | 732/1473 [01:58<06:15,  1.98it/s][A
 50%|████▉     | 733/1473 [01:58<05:07,  2.41it/s][A
 50%|████▉     | 734/1473 [01:58<04:14,  2.90it/s][A
 50%|████▉     | 735/1473 [01:58<03:32,  3.47it/s][A
 50%|████▉     | 736/1473 [01:58<03:02,  4.03it/s][A
 50%|█████     | 737/1473 [01:59<02:39,  4.60it/s][A
 50%|█████     | 738/1473 [01:59<02:24,  5.07it/s][A
 50%|█████     | 739/1473 [01:59<02:15,  5.42it/s][A
 50%|█████     | 740/1473 [01:59<02:28,  4.94it/s][A
 50%|█████     | 742/1473 [01:59<01:57,  6.23it/s][A
 50%|█████     | 743/1473 [02:00<01:54,  6.39it/s][A
 51%|█████     | 744/1473 [02:00<01:51,  6.55it/s][A
 51%|█████     | 745/1473 [02:00<01:51,  6.55it/s][A
 51%|█████     | 746/1473 [02:00<01:49,  6.65it/s][A
 51%|█████     | 747/1473 [02:00<01:48,  6.70it/s][A
 51%|█████     | 748/1473 [02:00<01:46,  6.78it/s][A
 51%|█████     | 749/1473 [02:00<01:45,  6.87it/s][A
 51%|█████     | 750/1473 [02:01<02:06,  5.73it/s][A
 51%|█████     | 752/1473 [

Saving new checkpoint



 63%|██████▎   | 930/1473 [02:30<08:10,  1.11it/s][A

Checkpint saving process done..



 63%|██████▎   | 932/1473 [02:30<04:59,  1.80it/s][A
 63%|██████▎   | 933/1473 [02:30<04:04,  2.21it/s][A
 63%|██████▎   | 934/1473 [02:30<03:20,  2.69it/s][A
 63%|██████▎   | 935/1473 [02:30<02:45,  3.24it/s][A
 64%|██████▎   | 936/1473 [02:31<02:20,  3.82it/s][A
 64%|██████▎   | 937/1473 [02:31<02:02,  4.37it/s][A
 64%|██████▎   | 938/1473 [02:31<01:48,  4.94it/s][A
 64%|██████▎   | 939/1473 [02:31<01:41,  5.27it/s][A
 64%|██████▍   | 940/1473 [02:31<01:50,  4.84it/s][A
 64%|██████▍   | 942/1473 [02:31<01:26,  6.17it/s][A
 64%|██████▍   | 943/1473 [02:32<01:23,  6.38it/s][A
 64%|██████▍   | 944/1473 [02:32<01:20,  6.54it/s][A
 64%|██████▍   | 945/1473 [02:32<01:19,  6.61it/s][A
 64%|██████▍   | 946/1473 [02:32<01:19,  6.62it/s][A
 64%|██████▍   | 947/1473 [02:32<01:18,  6.72it/s][A
 64%|██████▍   | 948/1473 [02:32<01:16,  6.85it/s][A
 64%|██████▍   | 949/1473 [02:32<01:15,  6.94it/s][A
 64%|██████▍   | 950/1473 [02:33<01:29,  5.82it/s][A
 65%|██████▍   | 952/1473 [

Saving new checkpoint



 77%|███████▋  | 1130/1473 [03:02<04:50,  1.18it/s][A

Checkpint saving process done..



 77%|███████▋  | 1132/1473 [03:02<02:56,  1.93it/s][A
 77%|███████▋  | 1133/1473 [03:02<02:24,  2.35it/s][A
 77%|███████▋  | 1134/1473 [03:02<01:58,  2.86it/s][A
 77%|███████▋  | 1135/1473 [03:02<01:39,  3.41it/s][A
 77%|███████▋  | 1136/1473 [03:02<01:24,  3.99it/s][A
 77%|███████▋  | 1137/1473 [03:03<01:13,  4.58it/s][A
 77%|███████▋  | 1138/1473 [03:03<01:06,  5.06it/s][A
 77%|███████▋  | 1139/1473 [03:03<01:02,  5.37it/s][A
 77%|███████▋  | 1140/1473 [03:03<01:06,  4.99it/s][A
 78%|███████▊  | 1142/1473 [03:03<00:52,  6.30it/s][A
 78%|███████▊  | 1143/1473 [03:03<00:51,  6.45it/s][A
 78%|███████▊  | 1144/1473 [03:04<00:50,  6.48it/s][A
 78%|███████▊  | 1145/1473 [03:04<00:50,  6.52it/s][A
 78%|███████▊  | 1146/1473 [03:04<00:50,  6.50it/s][A
 78%|███████▊  | 1147/1473 [03:04<00:48,  6.73it/s][A
 78%|███████▊  | 1148/1473 [03:04<00:47,  6.79it/s][A
 78%|███████▊  | 1149/1473 [03:04<00:47,  6.81it/s][A
 78%|███████▊  | 1150/1473 [03:05<00:55,  5.83it/s][A
 78%|████

Saving new checkpoint



 90%|█████████ | 1330/1473 [03:34<02:27,  1.03s/it][A

Checkpint saving process done..



 90%|█████████ | 1332/1473 [03:34<01:27,  1.61it/s][A
 90%|█████████ | 1333/1473 [03:35<01:10,  1.98it/s][A
 91%|█████████ | 1334/1473 [03:35<00:56,  2.44it/s][A
 91%|█████████ | 1335/1473 [03:35<00:46,  2.96it/s][A
 91%|█████████ | 1336/1473 [03:35<00:39,  3.51it/s][A
 91%|█████████ | 1337/1473 [03:35<00:33,  4.11it/s][A
 91%|█████████ | 1338/1473 [03:35<00:28,  4.68it/s][A
 91%|█████████ | 1339/1473 [03:35<00:26,  5.10it/s][A
 91%|█████████ | 1340/1473 [03:36<00:27,  4.78it/s][A
 91%|█████████ | 1342/1473 [03:36<00:21,  6.04it/s][A
 91%|█████████ | 1343/1473 [03:36<00:21,  6.13it/s][A
 91%|█████████ | 1344/1473 [03:36<00:20,  6.45it/s][A
 91%|█████████▏| 1345/1473 [03:36<00:19,  6.62it/s][A
 91%|█████████▏| 1346/1473 [03:36<00:19,  6.60it/s][A
 91%|█████████▏| 1347/1473 [03:37<00:18,  6.71it/s][A
 92%|█████████▏| 1348/1473 [03:37<00:18,  6.73it/s][A
 92%|█████████▏| 1349/1473 [03:37<00:18,  6.78it/s][A
 92%|█████████▏| 1350/1473 [03:37<00:21,  5.76it/s][A
 92%|████

Saving new checkpoint



  4%|▍         | 60/1473 [00:11<23:47,  1.01s/it][A

Checkpint saving process done..



  4%|▍         | 62/1473 [00:11<14:24,  1.63it/s][A
  4%|▍         | 63/1473 [00:12<11:40,  2.01it/s][A
  4%|▍         | 64/1473 [00:12<09:28,  2.48it/s][A
  4%|▍         | 65/1473 [00:12<07:47,  3.01it/s][A
  4%|▍         | 66/1473 [00:12<06:32,  3.59it/s][A
  5%|▍         | 67/1473 [00:12<05:38,  4.15it/s][A
  5%|▍         | 68/1473 [00:12<04:56,  4.75it/s][A
  5%|▍         | 69/1473 [00:12<04:37,  5.07it/s][A
  5%|▍         | 70/1473 [00:13<04:52,  4.80it/s][A
  5%|▍         | 72/1473 [00:13<03:48,  6.13it/s][A
  5%|▍         | 73/1473 [00:13<03:42,  6.30it/s][A
  5%|▌         | 74/1473 [00:13<03:35,  6.49it/s][A
  5%|▌         | 75/1473 [00:13<03:36,  6.47it/s][A
  5%|▌         | 76/1473 [00:13<03:30,  6.62it/s][A
  5%|▌         | 77/1473 [00:14<03:27,  6.73it/s][A
  5%|▌         | 78/1473 [00:14<03:25,  6.80it/s][A
  5%|▌         | 79/1473 [00:14<03:23,  6.84it/s][A
  5%|▌         | 80/1473 [00:14<04:03,  5.72it/s][A
  6%|▌         | 82/1473 [00:14<03:24,  6.81i

Saving new checkpoint



 18%|█▊        | 260/1473 [00:44<22:21,  1.11s/it][A

Checkpint saving process done..



 18%|█▊        | 262/1473 [00:44<13:24,  1.51it/s][A
 18%|█▊        | 263/1473 [00:45<10:48,  1.87it/s][A
 18%|█▊        | 264/1473 [00:45<08:41,  2.32it/s][A
 18%|█▊        | 265/1473 [00:45<07:05,  2.84it/s][A
 18%|█▊        | 266/1473 [00:45<05:54,  3.41it/s][A
 18%|█▊        | 267/1473 [00:45<05:01,  4.00it/s][A
 18%|█▊        | 268/1473 [00:45<04:22,  4.60it/s][A
 18%|█▊        | 269/1473 [00:45<04:02,  4.97it/s][A
 18%|█▊        | 270/1473 [00:46<04:18,  4.66it/s][A
 18%|█▊        | 272/1473 [00:46<03:19,  6.01it/s][A
 19%|█▊        | 273/1473 [00:46<03:11,  6.27it/s][A
 19%|█▊        | 274/1473 [00:46<03:04,  6.48it/s][A
 19%|█▊        | 275/1473 [00:46<03:01,  6.60it/s][A
 19%|█▊        | 276/1473 [00:46<02:59,  6.66it/s][A
 19%|█▉        | 277/1473 [00:47<02:59,  6.65it/s][A
 19%|█▉        | 278/1473 [00:47<02:56,  6.77it/s][A
 19%|█▉        | 279/1473 [00:47<02:53,  6.89it/s][A
 19%|█▉        | 280/1473 [00:47<03:26,  5.77it/s][A
 19%|█▉        | 282/1473 [

Saving new checkpoint



 31%|███       | 460/1473 [01:16<14:36,  1.16it/s][A

Checkpint saving process done..



 31%|███▏      | 462/1473 [01:16<08:56,  1.88it/s][A
 31%|███▏      | 463/1473 [01:16<07:18,  2.30it/s][A
 32%|███▏      | 464/1473 [01:16<05:59,  2.80it/s][A
 32%|███▏      | 465/1473 [01:17<05:00,  3.35it/s][A
 32%|███▏      | 466/1473 [01:17<04:15,  3.94it/s][A
 32%|███▏      | 467/1473 [01:17<03:43,  4.50it/s][A
 32%|███▏      | 468/1473 [01:17<03:19,  5.03it/s][A
 32%|███▏      | 469/1473 [01:17<03:06,  5.40it/s][A
 32%|███▏      | 470/1473 [01:17<03:22,  4.95it/s][A
 32%|███▏      | 472/1473 [01:18<02:39,  6.28it/s][A
 32%|███▏      | 473/1473 [01:18<02:34,  6.48it/s][A
 32%|███▏      | 474/1473 [01:18<02:31,  6.59it/s][A
 32%|███▏      | 475/1473 [01:18<02:30,  6.62it/s][A
 32%|███▏      | 476/1473 [01:18<02:29,  6.67it/s][A
 32%|███▏      | 477/1473 [01:18<02:26,  6.78it/s][A
 32%|███▏      | 478/1473 [01:19<02:23,  6.94it/s][A
 33%|███▎      | 479/1473 [01:19<02:23,  6.94it/s][A
 33%|███▎      | 480/1473 [01:19<02:51,  5.79it/s][A
 33%|███▎      | 482/1473 [

Saving new checkpoint



 45%|████▍     | 660/1473 [01:52<26:35,  1.96s/it][A

Checkpint saving process done..



 45%|████▍     | 662/1473 [01:52<15:16,  1.13s/it][A
 45%|████▌     | 663/1473 [01:52<11:59,  1.13it/s][A
 45%|████▌     | 664/1473 [01:52<09:20,  1.44it/s][A
 45%|████▌     | 665/1473 [01:52<07:18,  1.84it/s][A
 45%|████▌     | 666/1473 [01:52<05:46,  2.33it/s][A
 45%|████▌     | 667/1473 [01:53<04:39,  2.88it/s][A
 45%|████▌     | 668/1473 [01:53<03:51,  3.48it/s][A
 45%|████▌     | 669/1473 [01:53<03:19,  4.04it/s][A
 45%|████▌     | 670/1473 [01:53<03:17,  4.07it/s][A
 46%|████▌     | 672/1473 [01:53<02:26,  5.47it/s][A
 46%|████▌     | 673/1473 [01:53<02:17,  5.82it/s][A
 46%|████▌     | 674/1473 [01:54<02:10,  6.11it/s][A
 46%|████▌     | 675/1473 [01:54<02:05,  6.36it/s][A
 46%|████▌     | 676/1473 [01:54<02:02,  6.50it/s][A
 46%|████▌     | 677/1473 [01:54<02:00,  6.60it/s][A
 46%|████▌     | 678/1473 [01:54<01:59,  6.66it/s][A
 46%|████▌     | 679/1473 [01:54<01:56,  6.84it/s][A
 46%|████▌     | 680/1473 [01:55<02:18,  5.72it/s][A
 46%|████▋     | 682/1473 [

Saving new checkpoint



 58%|█████▊    | 860/1473 [02:24<10:28,  1.02s/it][A

Checkpint saving process done..



 59%|█████▊    | 862/1473 [02:24<06:15,  1.63it/s][A
 59%|█████▊    | 863/1473 [02:25<05:05,  2.00it/s][A
 59%|█████▊    | 864/1473 [02:25<04:07,  2.46it/s][A
 59%|█████▊    | 865/1473 [02:25<03:23,  2.99it/s][A
 59%|█████▉    | 866/1473 [02:25<02:50,  3.56it/s][A
 59%|█████▉    | 867/1473 [02:25<02:26,  4.13it/s][A
 59%|█████▉    | 868/1473 [02:25<02:08,  4.70it/s][A
 59%|█████▉    | 869/1473 [02:25<02:00,  5.02it/s][A
 59%|█████▉    | 870/1473 [02:26<02:07,  4.75it/s][A
 59%|█████▉    | 872/1473 [02:26<01:39,  6.01it/s][A
 59%|█████▉    | 873/1473 [02:26<01:36,  6.21it/s][A
 59%|█████▉    | 874/1473 [02:26<01:34,  6.37it/s][A
 59%|█████▉    | 875/1473 [02:26<01:33,  6.36it/s][A
 59%|█████▉    | 876/1473 [02:26<01:32,  6.49it/s][A
 60%|█████▉    | 877/1473 [02:27<01:30,  6.60it/s][A
 60%|█████▉    | 878/1473 [02:27<01:31,  6.47it/s][A
 60%|█████▉    | 879/1473 [02:27<01:28,  6.75it/s][A
 60%|█████▉    | 880/1473 [02:27<01:43,  5.73it/s][A
 60%|█████▉    | 882/1473 [

Saving new checkpoint



 72%|███████▏  | 1060/1473 [02:56<05:35,  1.23it/s][A
 72%|███████▏  | 1061/1473 [02:56<04:08,  1.66it/s][A

Checkpint saving process done..



 72%|███████▏  | 1062/1473 [02:56<03:16,  2.09it/s][A
 72%|███████▏  | 1063/1473 [02:56<02:37,  2.60it/s][A
 72%|███████▏  | 1064/1473 [02:56<02:05,  3.25it/s][A
 72%|███████▏  | 1065/1473 [02:56<01:45,  3.89it/s][A
 72%|███████▏  | 1066/1473 [02:57<01:31,  4.47it/s][A
 72%|███████▏  | 1067/1473 [02:57<01:20,  5.06it/s][A
 73%|███████▎  | 1068/1473 [02:57<01:13,  5.51it/s][A
 73%|███████▎  | 1069/1473 [02:57<01:10,  5.76it/s][A
 73%|███████▎  | 1070/1473 [02:57<01:18,  5.12it/s][A
 73%|███████▎  | 1072/1473 [02:57<01:03,  6.33it/s][A
 73%|███████▎  | 1073/1473 [02:58<01:01,  6.53it/s][A
 73%|███████▎  | 1074/1473 [02:58<00:59,  6.67it/s][A
 73%|███████▎  | 1075/1473 [02:58<00:58,  6.76it/s][A
 73%|███████▎  | 1076/1473 [02:58<00:58,  6.75it/s][A
 73%|███████▎  | 1077/1473 [02:58<00:58,  6.77it/s][A
 73%|███████▎  | 1078/1473 [02:58<00:57,  6.81it/s][A
 73%|███████▎  | 1079/1473 [02:58<00:57,  6.91it/s][A
 73%|███████▎  | 1080/1473 [02:59<01:06,  5.88it/s][A
 73%|████

Saving new checkpoint



 86%|████████▌ | 1260/1473 [03:28<03:30,  1.01it/s][A

Checkpint saving process done..



 86%|████████▌ | 1262/1473 [03:28<02:06,  1.66it/s][A
 86%|████████▌ | 1263/1473 [03:28<01:43,  2.03it/s][A
 86%|████████▌ | 1264/1473 [03:28<01:23,  2.51it/s][A
 86%|████████▌ | 1265/1473 [03:29<01:08,  3.05it/s][A
 86%|████████▌ | 1266/1473 [03:29<00:57,  3.62it/s][A
 86%|████████▌ | 1267/1473 [03:29<00:49,  4.19it/s][A
 86%|████████▌ | 1268/1473 [03:29<00:42,  4.78it/s][A
 86%|████████▌ | 1269/1473 [03:29<00:39,  5.14it/s][A
 86%|████████▌ | 1270/1473 [03:29<00:42,  4.82it/s][A
 86%|████████▋ | 1272/1473 [03:30<00:32,  6.12it/s][A
 86%|████████▋ | 1273/1473 [03:30<00:31,  6.31it/s][A
 86%|████████▋ | 1274/1473 [03:30<00:30,  6.49it/s][A
 87%|████████▋ | 1275/1473 [03:30<00:30,  6.53it/s][A
 87%|████████▋ | 1276/1473 [03:30<00:31,  6.34it/s][A
 87%|████████▋ | 1277/1473 [03:30<00:29,  6.55it/s][A
 87%|████████▋ | 1278/1473 [03:31<00:29,  6.57it/s][A
 87%|████████▋ | 1279/1473 [03:31<00:29,  6.65it/s][A
 87%|████████▋ | 1280/1473 [03:31<00:34,  5.65it/s][A
 87%|████

Saving new checkpoint



 99%|█████████▉| 1460/1473 [04:02<00:15,  1.18s/it][A

Checkpint saving process done..



 99%|█████████▉| 1462/1473 [04:02<00:07,  1.42it/s][A
 99%|█████████▉| 1463/1473 [04:02<00:05,  1.77it/s][A
 99%|█████████▉| 1464/1473 [04:03<00:04,  2.19it/s][A
 99%|█████████▉| 1465/1473 [04:03<00:02,  2.70it/s][A
100%|█████████▉| 1466/1473 [04:03<00:02,  3.29it/s][A
100%|█████████▉| 1467/1473 [04:03<00:01,  3.91it/s][A
100%|█████████▉| 1468/1473 [04:03<00:01,  4.50it/s][A
100%|█████████▉| 1469/1473 [04:03<00:00,  4.97it/s][A
100%|█████████▉| 1470/1473 [04:03<00:00,  4.66it/s][A
100%|█████████▉| 1472/1473 [04:04<00:00,  5.98it/s][A
100%|██████████| 1473/1473 [04:04<00:00,  6.03it/s]
Epoch: 100%|██████████| 3/3 [11:58<00:00, 239.39s/it]


In [None]:
# Average train_loss per step
global_step, tr_loss = global_step, train_loss_val / (global_step) # +1 only while testing
print("Final Global step- ", global_step)
print("Average training loss per step- ", tr_loss)

Final Global step-  441
Average training loss per step-  2.2707876763359356


In [None]:
# save the last model also
# global_step = 3  # remove this also; fetch the last global step value
chkpt_dir = qna_model_chkpts_dir + "/qna_chkpt_"+str(global_step)
chkpts_dir_name.append("qna_chkpt_"+str(global_step))

# save_model(model, tokenizer, chkpt_dir, args_dir)
save_model(model, tokenizer, chkpt_dir, global_step, args_dir)
print("Last model state saved")

Checkpint saving process done..
Last model state saved


In [None]:
# Till now,
# QnA trained -> checkpoints saved -> Last model state saved
print("Checkpoints saved with the name- ", chkpts_dir_name)

Checkpoints saved with the name-  ['qna_chkpt_20', 'qna_chkpt_40', 'qna_chkpt_60', 'qna_chkpt_80', 'qna_chkpt_100', 'qna_chkpt_120', 'qna_chkpt_140', 'qna_chkpt_160', 'qna_chkpt_180', 'qna_chkpt_200', 'qna_chkpt_220', 'qna_chkpt_240', 'qna_chkpt_260', 'qna_chkpt_280', 'qna_chkpt_300', 'qna_chkpt_320', 'qna_chkpt_340', 'qna_chkpt_360', 'qna_chkpt_380', 'qna_chkpt_400', 'qna_chkpt_420', 'qna_chkpt_440', 'qna_chkpt_441']


In [None]:
# Start Evaluation
print("Starting Evaluation of QnA on dev data based on perplexity")

Starting Evaluation of QnA on dev data based on perplexity


In [None]:
# Get validation dataset and dataloader
import tqdm
val_dataset = QnA_Dataset(qna_dataset_dir,
                                            INPUT_FORMAT_CONFIG,
                                            tokenizer,val_q, val_ans,
                                            limit_examples = None,
                                            evaluate = True, split_type = "dev")

val_sampler = SequentialSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, sampler = val_sampler, batch_size = eval_batch_size)

print("Validation dataset and dataloader created")

100%|██████████| 5415/5415 [00:15<00:00, 355.85it/s]




 After conversion-  {'sent1_tokens': [2061, 389, 262, 13820, 329, 2908, 268, 1287, 2011, 27189, 5633], 'sent2_tokens': [21327, 11, 691, 4318, 4755, 4369, 468, 281, 4050, 3513, 357, 3826, 2029, 737, 1318, 389, 645, 1900, 50019, 329, 597, 286, 777, 11916, 13, 7929, 425, 3513, 743, 6211, 29617, 19458, 291, 13820, 11, 355, 880, 355, 3518, 11, 34266, 393, 4046, 9102, 13]}


 Doing Preprocess each sample


 After preprocessing-  {'sent1_tokens': [2061, 389, 262, 13820, 329, 2908, 268, 1287, 2011, 27189, 5633], 'sent2_tokens': [21327, 11, 691, 4318, 4755, 4369, 468, 281, 4050, 3513, 357, 3826, 2029, 737, 1318, 389, 645, 1900, 50019, 329, 597, 286, 777, 11916, 13, 7929, 425, 3513, 743, 6211, 29617, 19458, 291, 13820, 11, 355, 880, 355, 3518, 11, 34266, 393, 4046, 9102, 13], 'prefix_sent': array([50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
       50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
       50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 502

In [None]:
# Validation meta Information
print("Num of examples- ", len(val_dataset))
print("Batch size- ", eval_batch_size)

Num of examples-  5415
Batch size-  5


In [None]:
# Fn to evaluate on a DVP checkpoint
def evaluate(model, tokenizer, chkpt_dir_name, val_dataloader):
  val_loss = 0.0
  model.eval()

  for i, batch in enumerate(val_dataloader):
    sentences = batch["sentence"].to(device)
    labels = batch["label"].to(device)
    segments = batch["segment"].to(device)

    with torch.no_grad():
      op = model(input_ids=sentences, token_type_ids=segments, labels=labels)
      loss_val = op.loss.item()

    val_loss += loss_val
    # break

  avg_val_loss = val_loss / (i + 1) # per batch average loss
  perplexity = torch.exp(torch.tensor(avg_val_loss)) # perplexity of exp(avg_loss)

  return perplexity

In [None]:
chkpts_dir_name = chkpts_dir_name[2:]

In [None]:
# Start evaluating the checkpoints on dev data and using perplexity as a measure
perplexity_list = []

for chkpt_name in chkpts_dir_name:

  print("Evaluating ", chkpt_name)
  chkpt_to_load = qna_model_chkpts_dir + chkpt_name
  model = model_class.from_pretrained(chkpt_to_load)
  tokenizer = tokenizer_class.from_pretrained(chkpt_to_load, do_lower_case = True)
  model.to(device)

  print("Checkpoint- " + chkpt_name + " loaded")

  # evaluate loaded
  print("Evaluating on loaded checkpoint")
  perplexity = evaluate(model, tokenizer, chkpt_dir, val_dataloader)
  perplexity_list.append((chkpt_name, perplexity))

print("QnA evaluated on all the saved checkpoints")

Evaluating  qna_chkpt_60


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_60 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_80


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_80 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_100


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_100 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_120


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_120 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_140


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_140 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_160


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_160 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_180


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_180 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_200


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_200 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_220


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_220 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_240


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_240 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_260


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_260 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_280


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_280 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_300


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_300 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_320


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_320 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_340


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_340 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_360


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_360 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_380


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_380 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_400


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_400 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_420


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_420 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_440


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_440 loaded
Evaluating on loaded checkpoint
Evaluating  qna_chkpt_441


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Checkpoint- qna_chkpt_441 loaded
Evaluating on loaded checkpoint
QnA evaluated on all the saved checkpoints


In [None]:
# Sort perplexity list in increasing order to get best model
perplexity_list.sort(key=lambda x: x[1].item())
top_chkpt_name = perplexity_list[0][0]

print("Top performing checkpoint is- ", top_chkpt_name)

Top performing checkpoint is-  qna_chkpt_441


In [None]:
# Evaluation on dev data done
chkpt_to_load = qna_model_chkpts_dir + top_chkpt_name
print(chkpt_to_load)

/content/drive/MyDrive/QnA/qna_chkpt_441


In [None]:
# just to verify that copied model is loading correctly or not
chkpt_to_load = qna_model_chkpts_dir + top_chkpt_name
model = model_class.from_pretrained(chkpt_to_load)
tokenizer = tokenizer_class.from_pretrained(chkpt_to_load, do_lower_case = True)

print("Model loaded successfully..!!")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Model loaded successfully..!!


In [None]:
# Get output logits from paraphraser
def get_logits(model, idx, sents, segments, past):
  if idx == 0:
      pred = model(input_ids = sents, token_type_ids = segments, return_dict=True)
  else:
      # used the cached representations to speed up decoding
      pred = model(input_ids = sents[:, -1:], token_type_ids = segments[:, -1:], past_key_values = past, return_dict = True)

  logits = pred['logits']
  past_keys = pred['past_key_values']

  return logits, past_keys

In [None]:
# Decide generation lenght and get converted output and score
def generate(model, sents_to_paraphrase, segments, eos_token_id, top_p, top_k, len_to_gen):
  batch_size = sents_to_paraphrase.shape[0] # total sents in batch

  eos_emitted = [False for _ in range(batch_size)]
  scores = [{"score": 0, "sequence": []} for _ in range(batch_size)]

  with torch.no_grad():
    past_keys = None

    for i in range(len_to_gen):
      op_logits, past_keys = get_logits(model, i, sents_to_paraphrase, segments, past_keys)
      next_token_logits = op_logits[:, -1, :]
      original_scores = nn.Softmax(dim = -1)(next_token_logits)

      # do nucleas filtering and greedy decoding
      filtered_logits = next_token_logits
      next_token = torch.argmax(filtered_logits, dim = -1).unsqueeze(-1)

      # if top_k in [0, 1] and top_p == 0.0: # mainly to control the output diversity
      #   # greedy sampling
      #   next_token = torch.argmax(filtered_logits, dim = -1).unsqueeze(-1)
      # else :
      #   next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples = 1)

      sents_to_paraphrase = torch.cat((sents_to_paraphrase, next_token), dim=1)
      segments = torch.cat((segments, segments[:, -1:]), dim=1)

      for batch_elem in range(batch_size):
        if next_token[batch_elem].item() == eos_token_id:
            eos_emitted[batch_elem] = True

      if len_to_gen is None and all(eos_emitted):
        break

  return sents_to_paraphrase, scores

In [None]:
# Preprocess input from to paraphrase
def preprocess2(exp, tokenizer, config):
  max_prefix_len = config["max_prefix_length"]
  sent1 = np.array(exp["sent1_tokens"])

  # truncate
  if(len(sent1) > max_prefix_len):
    sent1 = sent1[:max_prefix_len]

  # add padding; left padding to prefix and right padding to suffix
  count_pad_tokens_prefix = max_prefix_len - len(sent1)
  sent1 = np.pad(sent1, (count_pad_tokens_prefix, 0), constant_values = tokenizer.pad_token_id)

  # sentence to input gpt2
  sentence_to_input_gpt2 = np.concatenate([sent1, [tokenizer.bos_token_id]]).astype(np.int64) # [sent1, <bos>]

  # segment
  segment = np.concatenate([
      [tokenizer.additional_special_tokens_ids[0] for _ in sent1],
      [tokenizer.additional_special_tokens_ids[1]]
  ]).astype(np.int64)

  exp["input"] = sentence_to_input_gpt2
  exp["segment"] = segment

  return exp

In [None]:
# Generate paraphrased sentences for batch of input sentences
def generate_paraphrased_sent(model, top_p, top_k, sents, config, device, tokenizer):
  examples = []

  for sent in sents:
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sent))
    dd = {"sent1_tokens":token_ids}

    dd_process = preprocess2(dd, tokenizer, config)
    examples.append(dd_process)

  init_context_size = 1 + config["max_prefix_length"]

  output, scores = generate(model.to(device),
      sents_to_paraphrase = torch.tensor([inst["input"] for inst in examples]).to(device),
      segments = torch.tensor([inst["segment"] for inst in examples]).to(device),
      eos_token_id = tokenizer.eos_token_id,
      top_p = top_p, top_k = top_k,
      len_to_gen =  config["max_suffix_length"] + 1  # +1 for <eos>
  )

  all_ops = []
  for idx in range(len(output)):
    exmp = examples[idx]
    curr_out = output[idx, init_context_size:].tolist()

    if tokenizer.eos_token_id in curr_out:
      curr_out = curr_out[:curr_out.index(tokenizer.eos_token_id)]

    all_ops.append(tokenizer.decode(curr_out, clean_up_tokenization_spaces = True, skip_special_tokens = True))

  return all_ops, scores

In [None]:
top_p = 0.0
top_k = 1.0

In [None]:
x,y = generate_paraphrased_sent(model, top_p, top_k, ["How to diagnose Parasites - Cysticercosis?"], INPUT_FORMAT_CONFIG, device, tokenizer)

In [None]:
print(x)

["How is cysticercosis diagnosed? Cysticercosis is diagnosed by a doctor's examination of the skin and mucous membranes. The doctor will examine the skin and mucous membranes to find out if the cysticerc"]
