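# Training

Train a character-level masked-language model (`BertForMaskedLM`) that restores a single `[MASK]`ed character in English sentences.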

## Module smoke test

In [1]:
from char_mlm import CharMLMDataset

# Smoke test: encode two hand-masked strings and check the round trip through the tokenizer.
# Named `smoke_ds` (not `test`) so it is not confused with the held-out `test` split built later.
smoke_ds = CharMLMDataset(
    masked_texts=['t[MASK]st', 'hel[MASK]o'],
    label_texts=['test', 'hello']
)

print(smoke_ds.batch_encoding)
decoded_input = smoke_ds.tokenizer.decode(smoke_ds[0]['input_ids'])
decoded_label = smoke_ds.tokenizer.decode(smoke_ds[0]['labels'])
print(decoded_input)
print(decoded_label)

# The shorter first example is padded to the length of the longest one in the batch.
assert len(smoke_ds) == 2
assert decoded_input == '[CLS]t[MASK]st[SEP][PAD]', decoded_input
assert decoded_label == '[CLS]test[SEP][PAD]', decoded_label


Inputs: Encoding texts...: 100%|██████████| 2/2 [00:00<00:00, 6517.95it/s]
Labels: Encoding texts...: 100%|██████████| 2/2 [00:00<00:00, 28339.89it/s]

{'input_ids': tensor([[101, 316, 103, 315, 316, 102,   0],
        [101, 304, 301, 308, 103, 311, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[101, 316, 301, 315, 316, 102,   0],
        [101, 304, 301, 308, 308, 311, 102]])}
[CLS]t[MASK]st[SEP][PAD]
[CLS]test[SEP][PAD]





## Loading the dataset

In [1]:
import numbers
from typing import List, Union

import pandas as pd

from char_mlm import CharMLMDataset

def mask_idx(text: str, idx: Union[int, List[int]]) -> str:
    """Replace the character(s) of ``text`` at ``idx`` with the literal ``'[MASK]'`` token.

    Args:
        text: Sentence to mask.
        idx: A single position, or a list of positions. A single position may be
            any integral type (including numpy integers). Negative positions count
            from the end, as with list indexing.

    Returns:
        A new string. ``text`` itself is not modified.

    Raises:
        IndexError: If a position is out of range for ``text``.
    """
    chars = list(text)
    # isinstance (rather than `type(idx) == int`) also accepts int subclasses and numpy integers,
    # which would otherwise be iterated over and raise TypeError.
    positions = [idx] if isinstance(idx, numbers.Integral) else idx
    for pos in positions:
        chars[pos] = '[MASK]'
    return ''.join(chars)


def mask_sents(sents_origin: List[str]):
    """Expand every sentence into one single-mask example per character position.

    A sentence of length n contributes n examples. Example k has character k replaced
    by '[MASK]' (via ``mask_idx``) and is paired with the untouched sentence as its label.

    Returns:
        ``(masked_sentences, label_sentences)``: two equal-length lists, in the same order
        as the ``masked_texts`` / ``label_texts`` keywords used with CharMLMDataset.
    """
    pairs = [
        (mask_idx(sentence, position), sentence)
        for sentence in sents_origin
        for position in range(len(sentence))
    ]
    masked_sentences = [masked for masked, _ in pairs]
    label_sentences = [label for _, label in pairs]
    return masked_sentences, label_sentences

DATA_PATH = './Data/en_setence.csv'  # (sic) spelling matches the actual file name
N_SENTENCES = 2000  # cap the corpus: each sentence expands to one example per character
TEST_DIVISOR = 10   # hold out the first 1/TEST_DIVISOR of the sentences as the test split

# nrows reads only the rows we keep, instead of loading the whole CSV and slicing afterwards.
sents_origin = pd.read_csv(DATA_PATH, nrows=N_SENTENCES)['clean'].to_list()

# Split at the sentence level (before masking) so masked variants of one sentence never
# land in both splits.
# NOTE(review): the split is positional, not shuffled. If the CSV is ordered, the test split
# may not be representative.
n_test = len(sents_origin) // TEST_DIVISOR
test_sents_origin, train_sents_origin = sents_origin[:n_test], sents_origin[n_test:]

train = CharMLMDataset(*mask_sents(train_sents_origin))
test = CharMLMDataset(*mask_sents(test_sents_origin))

# mask_sents produces exactly one example per character.
assert len(train) == sum(map(len, train_sents_origin))
assert len(test) == sum(map(len, test_sents_origin))

print(f'train: {len(train)}, test: {len(test)}')


Inputs: Encoding texts...: 100%|██████████| 216308/216308 [00:06<00:00, 31928.20it/s]
Labels: Encoding texts...: 100%|██████████| 216308/216308 [00:06<00:00, 34415.22it/s]
Inputs: Encoding texts...: 100%|██████████| 19650/19650 [00:00<00:00, 32648.31it/s]
Labels: Encoding texts...: 100%|██████████| 19650/19650 [00:00<00:00, 37978.96it/s]

train: 216308, test: 19650





## Trainer & Model definition

In [2]:
from transformers import Trainer, BertForMaskedLM, BertConfig, TrainingArguments
import os
from datetime import datetime
import torch
import torch_ort

from transformers import set_seed

# The run directory is pinned (not datetime.now()) so that re-running this notebook resumes
# the same run's checkpoints. For a brand-new run use:
#   RUN_NAME = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
RUN_NAME = '2021-11-23-20-10-16'
MODEL_DIR = os.path.join('./models', RUN_NAME)
SEED = 42  # same value as the TrainingArguments default; passed explicitly below

# Seed before building the model. Trainer re-seeds when it is constructed, but by then the
# randomly initialised weights below already exist.
set_seed(SEED)

# Default BERT architecture, with room for sequences of up to 1024 positions.
# NOTE(review): vocab_size keeps BertConfig's default. The character ids seen in the smoke test
# are all < 320, so most of the embedding table is probably unused. Confirm against the
# CharTokenizer vocabulary size.
model_config = BertConfig(
    max_position_embeddings=1024,
)
model = BertForMaskedLM(model_config)

training_args = TrainingArguments(
    output_dir=MODEL_DIR,
    num_train_epochs=10,
    logging_dir=os.path.join(MODEL_DIR, 'tensorboard'),
    logging_strategy='epoch',
    log_level='warning',
    save_strategy='epoch',
    seed=SEED,
)
# NOTE(review): `_n_gpu` is a private TrainingArguments field. Forcing it to 1 presumably keeps
# Trainer from wrapping the model in DataParallel on a multi-GPU host. This is fragile across
# transformers versions. Exporting CUDA_VISIBLE_DEVICES before CUDA is initialised is the
# supported alternative.
training_args._n_gpu = 1

trainer = Trainer(
    model,
    training_args,
    train_dataset=train,
)

print('model:', model.device)
print('trainer:', training_args.device)


model: cuda:0
trainer: cuda:0


## Training

In [3]:
import os

from transformers.trainer_utils import get_last_checkpoint

# `resume_from_checkpoint=True` raises when output_dir holds no checkpoint yet (e.g. a fresh
# run directory). Resolve the latest checkpoint explicitly; None means train from scratch.
output_dir = training_args.output_dir
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None

trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()

# Evaluate on the held-out `test` split, which is never passed to the trainer for training.
test_result = trainer.evaluate(test)
print(test_result)

0it [00:00, ?it/s]

Step,Training Loss
54078,0.0033


## Training results in TensorBoard

In [None]:
%tensorboard --logdir models/2021-11-22-20-30-31/tensorboard/

## Prediction

In [None]:
from char_mlm import CharMLMDataset, CharTokenizer
from transformers import AutoModelForMaskedLM

# NOTE(review): this checkpoint comes from run 2021-11-22-20-30-31, not from the
# 2021-11-23-20-10-16 run configured for training above. Confirm this is intended.
CHECKPOINT_PATH = './models/2021-11-22-20-30-31/checkpoint-31566'

# Rebinds `model` to the loaded checkpoint, replacing the model trained earlier in this notebook.
model = AutoModelForMaskedLM.from_pretrained(CHECKPOINT_PATH)
tokenizer = CharTokenizer()


In [None]:
import torch

inputs = ['[MASK]ello there!']

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    logits = model(**tokenizer(inputs))['logits']

# Greedy decode: take the highest-scoring token at every position. The argmax covers all
# positions, not only [MASK], so unmasked characters can change in the output too.
outputs = tokenizer.batch_decode(torch.argmax(logits, -1))

print('inputs:')
print('\n'.join(['   ' + i for i in inputs]))
print('outputs:')
print('\n'.join(['   ' + o.replace("[PAD]", "") for o in outputs]))


Encoding texts...: 100%|██████████| 1/1 [00:00<00:00, 9425.40it/s]

inputs:
   [MASK]ello there!
outputs:
   [CLS]mello tmerem[SEP]



