In [None]:
%pip install accelerate -U -qqq
%pip install transformers[torch] -qqq

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch
import yaml
from distutils.dir_util import copy_tree

In [3]:
import torch
import yaml
from src.data.datamodule import DataManager

In [6]:
import json
import os
import numpy as np
import re

from tqdm import tqdm

from src.data.mt_dataset import MTDataset_HF
from src.data.tokenizers.unif_tokenizers import UNIFTokenizer

In [60]:
# Load the data configuration. The context manager closes the file handle,
# which the previous bare `open(...)` call left open.
# NOTE(review): `yaml.Loader` can construct arbitrary Python objects. It is kept
# so existing configs load unchanged; prefer `yaml.safe_load` if the config
# only contains plain scalars, lists and mappings.
with open("configs/data_config.yaml", 'r', encoding='utf-8') as config_file:
    data_config = yaml.load(config_file, Loader=yaml.Loader)


def data_path(x):
    """Return the path of the JSON file for the data split ``x``.

    The path is the concatenation of ``data_config["path_repository"]``,
    ``"data/"``, ``data_config["data_language"]``, ``str(x)``,
    ``data_config["data_name_file"]`` and ``".json"``.
    """
    return (
        data_config["path_repository"]
        + "data/"
        + data_config["data_language"]
        + str(x)
        + data_config["data_name_file"]
        + ".json"
    )

In [61]:
# `config` is a second name for the same object as `data_config`.
config = data_config
device = "cpu"

# Tokenizer built from the vocabulary file under the repository's data folder
# and the pre-trained tokenizer name / maximum length given in the config.
tokenizer = UNIFTokenizer(
    path_tok=config["path_repository"] + "data/query_vocab.json",
    pre_train_name=config["pre_train_tokenizer"],
    pad_flag=True,
    max_length=config["max_sent_len"],
)

def prepare_data(path_data, drop_last=False):
    """Load a JSON split and wrap its tokenized question/query pairs in a dataset.

    Args:
        path_data: Path of a JSON file holding a list of samples; each sample
            is read through its ``'question'`` and ``'masked_query'`` keys.
        drop_last: Accepted for backward compatibility; it is not used.

    Returns:
        An ``MTDataset_HF`` built from the tokenized questions (source) and
        tokenized masked queries (target), created with the module-level
        ``device``.
    """
    # The context manager closes the file; the previous `json.load(open(...))`
    # left the handle open.
    with open(path_data, 'r', encoding="utf-8") as data_file:
        samples = json.load(data_file)

    # Read the sample limit from `data_config` rather than the `config` alias:
    # `config` is rebound to a model configuration later in the notebook, which
    # made this function fail when the data cell was re-run afterwards.
    samples = samples[:data_config["separate_batch"]]

    target_sentences = []
    source_sentences = []
    for sample in tqdm(samples, desc="Pars data"):
        target_sentences.append(sample['masked_query'])
        source_sentences.append(sample['question'])

    tokenized_source_sentences = [tokenizer.tkr(sentence) for sentence in source_sentences]
    tokenized_target_sentences = [tokenizer.tkr(sentence) for sentence in target_sentences]

    dataset = MTDataset_HF(tokenized_source_list=tokenized_source_sentences,
                           tokenized_target_list=tokenized_target_sentences,
                           device=device)
    return dataset

In [62]:
# Build the dev and test datasets. Despite the `_dataloader` suffix, both names
# hold the dataset object returned by `prepare_data`.
# `prepare_data` does not read `drop_last`, so the differing values below have
# no effect.
dev_dataloader = prepare_data(path_data=data_path("dev"), drop_last=False)
test_dataloader = prepare_data(path_data=data_path("test"), drop_last=True)


Pars data: 100%|██████████| 8420/8420 [00:00<00:00, 1791782.84it/s]
 22%|██▏       | 7000/31590 [1:43:46<6:04:33,  1.12it/s]
Pars data: 100%|██████████| 15877/15877 [00:00<00:00, 1578705.72it/s]


In [63]:
from transformers import AutoConfig, RobertaForCausalLM

# The model configuration gets its own name so that it no longer overwrites
# `config`, which holds the data configuration loaded earlier in the notebook.
model_config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
# Use the checkpoint as a standalone decoder so the LM head is trained causally.
model_config.is_decoder = True

# "FacebookAI/roberta-base" is a RoBERTa checkpoint, so it is loaded with the
# matching RoBERTa class instead of the XLM-RoBERTa one.
model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=model_config)
# Fall back to the CPU when no GPU is visible instead of failing on `.to('cuda')`.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [64]:
# Total number of elements across all parameter tensors of the model.
sum(p.numel() for p in model.parameters())

124697433

In [65]:
from transformers import Trainer, TrainingArguments

In [66]:
# Settings for the fine-tuning run: evaluation once per epoch, three epochs,
# logs written to "logs", checkpoint saving turned off.
training_args = TrainingArguments(
    output_dir="roberta-base-exp",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    logging_dir="logs",
    save_strategy="no",
)

In [67]:
from transformers import ProgressCallback, PrinterCallback

In [68]:
# No explicit PrinterCallback: the Trainer's default progress callback already
# prints every log dict, so adding PrinterCallback printed each line twice.
# NOTE(review): training uses the "dev" split and evaluation the "test" split;
# confirm that this is the intended setup.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dev_dataloader,
    eval_dataset=test_dataloader,
)

In [69]:
# Run fine-tuning. The result is kept because its `.metrics` are logged and
# saved in a later cell.
end_train = trainer.train()

 16%|█▌        | 500/3159 [00:30<02:40, 16.53it/s]

{'loss': 1.1346, 'learning_rate': 1.6834441278885725e-05, 'epoch': 0.47}
{'loss': 1.1346, 'learning_rate': 1.6834441278885725e-05, 'epoch': 0.47}


 32%|███▏      | 1000/3159 [01:01<02:12, 16.31it/s]

{'loss': 0.8333, 'learning_rate': 1.3668882557771448e-05, 'epoch': 0.95}
{'loss': 0.8333, 'learning_rate': 1.3668882557771448e-05, 'epoch': 0.95}


 33%|███▎      | 1053/3159 [01:04<02:04, 16.90it/s]
  0%|          | 0/1985 [00:00<?, ?it/s][A
  0%|          | 7/1985 [00:00<00:28, 69.06it/s][A
  1%|          | 14/1985 [00:00<00:32, 61.34it/s][A
  1%|          | 21/1985 [00:00<00:33, 59.29it/s][A
  1%|▏         | 27/1985 [00:00<00:33, 58.11it/s][A
  2%|▏         | 33/1985 [00:00<00:35, 55.52it/s][A
  2%|▏         | 39/1985 [00:00<00:34, 56.72it/s][A
  2%|▏         | 46/1985 [00:00<00:33, 57.74it/s][A
  3%|▎         | 53/1985 [00:00<00:32, 58.57it/s][A
  3%|▎         | 59/1985 [00:01<00:32, 58.68it/s][A
  3%|▎         | 65/1985 [00:01<00:32, 59.05it/s][A
  4%|▎         | 71/1985 [00:01<00:32, 59.10it/s][A
  4%|▍         | 78/1985 [00:01<00:32, 59.57it/s][A
  4%|▍         | 84/1985 [00:01<00:31, 59.58it/s][A
  5%|▍         | 91/1985 [00:01<00:31, 59.73it/s][A
  5%|▍         | 97/1985 [00:01<00:31, 59.63it/s][A
  5%|▌         | 103/1985 [00:01<00:31, 59.33it/s][A
  5%|▌         | 109/1985 [00:01<00:31, 59.15it/s][A
  

{'eval_loss': 0.797120988368988, 'eval_runtime': 33.7479, 'eval_samples_per_second': 470.458, 'eval_steps_per_second': 58.818, 'epoch': 1.0}
{'eval_loss': 0.797120988368988, 'eval_runtime': 33.7479, 'eval_samples_per_second': 470.458, 'eval_steps_per_second': 58.818, 'epoch': 1.0}


 47%|████▋     | 1500/3159 [02:05<01:42, 16.25it/s]  

{'loss': 0.7622, 'learning_rate': 1.0503323836657171e-05, 'epoch': 1.42}
{'loss': 0.7622, 'learning_rate': 1.0503323836657171e-05, 'epoch': 1.42}


 63%|██████▎   | 2000/3159 [02:35<01:09, 16.58it/s]

{'loss': 0.724, 'learning_rate': 7.337765115542894e-06, 'epoch': 1.9}
{'loss': 0.724, 'learning_rate': 7.337765115542894e-06, 'epoch': 1.9}


 67%|██████▋   | 2106/3159 [02:42<01:05, 16.19it/s]
  0%|          | 0/1985 [00:00<?, ?it/s][A
  0%|          | 7/1985 [00:00<00:29, 67.54it/s][A
  1%|          | 14/1985 [00:00<00:32, 61.26it/s][A
  1%|          | 21/1985 [00:00<00:33, 59.40it/s][A
  1%|▏         | 27/1985 [00:00<00:33, 58.18it/s][A
  2%|▏         | 33/1985 [00:00<00:33, 57.72it/s][A
  2%|▏         | 39/1985 [00:00<00:33, 57.53it/s][A
  2%|▏         | 45/1985 [00:00<00:33, 57.42it/s][A
  3%|▎         | 51/1985 [00:00<00:33, 57.31it/s][A
  3%|▎         | 57/1985 [00:00<00:33, 57.48it/s][A
  3%|▎         | 63/1985 [00:01<00:33, 57.45it/s][A
  3%|▎         | 69/1985 [00:01<00:33, 57.55it/s][A
  4%|▍         | 75/1985 [00:01<00:33, 57.74it/s][A
  4%|▍         | 81/1985 [00:01<00:32, 57.88it/s][A
  4%|▍         | 87/1985 [00:01<00:32, 58.08it/s][A
  5%|▍         | 93/1985 [00:01<00:32, 58.31it/s][A
  5%|▍         | 99/1985 [00:01<00:32, 58.38it/s][A
  5%|▌         | 105/1985 [00:01<00:32, 58.66it/s][A
  6

{'eval_loss': 0.7641430497169495, 'eval_runtime': 33.8832, 'eval_samples_per_second': 468.58, 'eval_steps_per_second': 58.584, 'epoch': 2.0}
{'eval_loss': 0.7641430497169495, 'eval_runtime': 33.8832, 'eval_samples_per_second': 468.58, 'eval_steps_per_second': 58.584, 'epoch': 2.0}


 79%|███████▉  | 2501/3159 [03:40<00:40, 16.38it/s]  

{'loss': 0.7033, 'learning_rate': 4.172206394428617e-06, 'epoch': 2.37}
{'loss': 0.7033, 'learning_rate': 4.172206394428617e-06, 'epoch': 2.37}


 95%|█████████▌| 3003/3159 [04:10<00:09, 16.31it/s]

{'loss': 0.6926, 'learning_rate': 1.00664767331434e-06, 'epoch': 2.85}
{'loss': 0.6926, 'learning_rate': 1.00664767331434e-06, 'epoch': 2.85}


100%|██████████| 3159/3159 [04:20<00:00, 16.85it/s]
  0%|          | 0/1985 [00:00<?, ?it/s][A
  0%|          | 8/1985 [00:00<00:28, 68.92it/s][A
  1%|          | 15/1985 [00:00<00:31, 63.39it/s][A
  1%|          | 22/1985 [00:00<00:31, 61.51it/s][A
  1%|▏         | 29/1985 [00:00<00:32, 60.52it/s][A
  2%|▏         | 36/1985 [00:00<00:32, 60.03it/s][A
  2%|▏         | 43/1985 [00:00<00:32, 59.59it/s][A
  2%|▏         | 49/1985 [00:00<00:32, 59.53it/s][A
  3%|▎         | 55/1985 [00:00<00:32, 59.32it/s][A
  3%|▎         | 61/1985 [00:01<00:32, 59.34it/s][A
  3%|▎         | 67/1985 [00:01<00:32, 59.27it/s][A
  4%|▎         | 73/1985 [00:01<00:32, 58.82it/s][A
  4%|▍         | 79/1985 [00:01<00:32, 58.88it/s][A
  4%|▍         | 85/1985 [00:01<00:32, 58.91it/s][A
  5%|▍         | 91/1985 [00:01<00:32, 59.16it/s][A
  5%|▍         | 97/1985 [00:01<00:32, 58.94it/s][A
  5%|▌         | 103/1985 [00:01<00:32, 58.63it/s][A
  5%|▌         | 109/1985 [00:01<00:31, 58.83it/s][A
  

{'eval_loss': 0.7546016573905945, 'eval_runtime': 33.8176, 'eval_samples_per_second': 469.489, 'eval_steps_per_second': 58.697, 'epoch': 3.0}
{'eval_loss': 0.7546016573905945, 'eval_runtime': 33.8176, 'eval_samples_per_second': 469.489, 'eval_steps_per_second': 58.697, 'epoch': 3.0}
{'train_runtime': 294.1969, 'train_samples_per_second': 85.861, 'train_steps_per_second': 10.738, 'train_loss': 0.802581847487616, 'epoch': 3.0}
{'train_runtime': 294.1969, 'train_samples_per_second': 85.861, 'train_steps_per_second': 10.738, 'train_loss': 0.802581847487616, 'epoch': 3.0}





In [70]:
import math
# Evaluate on the Trainer's evaluation dataset and report exp(eval_loss).
# NOTE(review): this is a token-level perplexity only if `eval_loss` is a mean
# per-token cross-entropy; confirm, including how padded positions are treated.
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

100%|██████████| 1985/1985 [00:33<00:00, 59.24it/s]

{'eval_loss': 0.7546016573905945, 'eval_runtime': 33.5232, 'eval_samples_per_second': 473.612, 'eval_steps_per_second': 59.213, 'epoch': 3.0}
Perplexity: 2.13





In [87]:
import random

In [88]:
# Move the model to the CPU; presumably this matches the dataset tensors, which
# were created with device="cpu" (TODO confirm).
trainer.model.to('cpu');
# Print five randomly chosen test examples next to a sampled generation.
for _ in range(5):
    index = random.randint(0, len(test_dataloader) - 1)
    sample = test_dataloader[index]  # fetch the example once and reuse it
    print("---------true-----------")
    print(tokenizer.decode(sample['input_ids']), tokenizer.decode(sample['labels']))
    print("---------predict-----------")
    # Only the input ids are passed, reshaped to a batch of one; no attention
    # mask is supplied to `generate`.
    prompt_ids = sample['input_ids'].reshape((1, -1))
    predict = trainer.model.generate(
        prompt_ids,
        max_new_tokens=40,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )[0]
    print(tokenizer.decode(predict))

---------true-----------
[CLS] What is the number in season of the episode whose production code is pabf05 ? [SEP]                    [CLS] SELECT # FROM table WHERE Production_code = STR_VALUE_1 [SEP]                              
---------predict-----------
[CLS] What is the number in season of the episode whose production code is pabf05 ? [SEP]                                SELECT   SELECT        SELECT  SELECT     SELECT  SELECT   SELECT  SELECT  
---------true-----------
[CLS] What is the highest pick number for player don barber ? [SEP]                           [CLS] SELECT MAX Pick_# FROM table WHERE Player = STR_VALUE_1 [SEP]                             
---------predict-----------
[CLS] What is the highest pick number for player don barber ? [SEP]                           SELECT Venue WHERE  SELECT Home_team   SELECT AND    SELECT         SELECT         SELECT FROM FROM Position Score Title = NUM_VALUE_1 Player
---------true-----------
[CLS] What was the original air date f

In [90]:
# Move the model to the CPU; presumably this matches the dataset tensors, which
# were created with device="cpu" (TODO confirm).
trainer.model.to('cpu');
# Print four randomly chosen dev examples next to a sampled generation.
for _ in range(4):
    index = random.randint(0, len(dev_dataloader) - 1)
    sample = dev_dataloader[index]  # fetch the example once and reuse it
    print("---------true-----------")
    print(tokenizer.decode(sample['input_ids']), tokenizer.decode(sample['labels']))
    print("---------predict-----------")
    # Only the input ids are passed, reshaped to a batch of one; no attention
    # mask is supplied to `generate`.
    prompt_ids = sample['input_ids'].reshape((1, -1))
    predict = trainer.model.generate(
        prompt_ids,
        max_new_tokens=40,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )[0]
    print(tokenizer.decode(predict))

---------true-----------
[CLS] What is the Deputy's affiliation in 1992 – 93 ? [SEP]                           [CLS] SELECT Deputy's_affiliation FROM table WHERE Year = STR_VALUE_1 [SEP]                              
---------predict-----------
[CLS] What is the Deputy's affiliation in 1992 – 93 ? [SEP]                           SELECT Result  SELECT     SELECT Grid Grid      SELECT FROM       SELECT  SELECT   SELECT FROM FROM table FROM Rank = = = NUM_VALUE_1 Score
---------true-----------
[CLS] What was the first leg of the semi - final ? [SEP]                            [CLS] SELECT First_leg FROM table WHERE Round = STR_VALUE_1 [SEP]                              
---------predict-----------
[CLS] What was the first leg of the semi - final ? [SEP]                            SELECT Bronze FROM  SELECT   SELECT STR_VALUE_3 AND Total = Opponent WHERE = Points = = = =                    
---------true-----------
[CLS] Which venue had an extra of Junior Race ? [SEP]                      

In [76]:
# Metrics reported by the training run (`end_train` is the result of
# `trainer.train()`).
metrics = end_train.metrics

# Log the training metrics, then save them through the Trainer under the
# "train" split name.
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

***** train metrics *****
  epoch                    =        3.0
  train_loss               =     0.8026
  train_runtime            = 0:04:54.19
  train_samples_per_second =     85.861
  train_steps_per_second   =     10.738
