In [1]:
# Local copy of /opt/conda/lib/python3.8/site-packages/transformers/trainer.py, modified at lines 2310-2311 and 2335-2336.
from trainer import Trainer

In [2]:
# from transformers import DistilBertTokenizer, DistilBertForMaskedLM, DataCollatorForLanguageModeling, Trainer, EarlyStoppingCallback, TrainingArguments
from transformers import DistilBertTokenizer, DistilBertForMaskedLM, DataCollatorForLanguageModeling, EarlyStoppingCallback, TrainingArguments
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import evaluate
import torch

In [3]:
import json
import os

In [4]:
# One pretrained DistilBERT checkpoint drives the model, the tokenizer and (via the tokenizer) the collator.
name_model = 'distilbert-base-uncased'
model = DistilBertForMaskedLM.from_pretrained(name_model)
tokenizer = DistilBertTokenizer.from_pretrained(name_model)
# Selects ~35% of the non-special tokens in each batch as MLM targets (transformers' default is 0.15).
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.35)

In [5]:
def prepare_one_row(tokens_raw, seq_max):
    """Wrap one row of token ids as [CLS] + row + [SEP] with its special-tokens mask.

    Unpadded variant of SimpleMLMDataset.transfer_one_row_to_sample.

    Args:
        tokens_raw: list of token ids for the row; a new list is built, the input is not modified.
        seq_max: accepted but not used -- no padding or truncation happens here.

    Returns:
        dict with 'input_ids' and 'special_tokens_mask' (1 at the [CLS]/[SEP] positions,
        0 at the row's own tokens) -- the same keys a tokenizer emits with
        return_special_tokens_mask=True.
    """
    # 101 / 102 are the [CLS] / [SEP] ids (SimpleMLMDataset.TOKEN_CLS / TOKEN_SEP).
    input_ids = [101] + tokens_raw + [102]
    num_inner = len(input_ids) - 2
    special_tokens_mask = [1] + [0] * num_inner + [1]
    return {'input_ids': input_ids, 'special_tokens_mask': special_tokens_mask}
# end

In [6]:
class SimpleMLMDataset(torch.utils.data.Dataset):
    """Stream pre-tokenized rows from a list of JSON files as padded MLM samples.

    Each file is expected to hold a JSON list of rows, each row a list of token ids
    without [CLS]/[SEP]. Rows are consumed sequentially: files are visited in the order
    given (wrapping around), and rows inside a file are taken from the end (list.pop()).

    NOTE(review): __getitem__ ignores `idx`; the returned sample depends only on the
    internal file/row cursor. Sampler shuffling therefore has no effect. With a DataLoader
    using num_workers > 0, each worker holds its own copy of the cursor and would yield
    the same rows -- confirm dataloader_num_workers stays 0. An eval pass covers exactly
    the eval rows only if `len_dataset` equals the total row count of the files -- TODO confirm.

    Args:
        paths_file: list of JSON file paths to read rows from.
        len_dataset: value reported by __len__. The default -1 makes len() raise
            ValueError, so pass the real row count when used with a sampler/Trainer.
    """
    TOKEN_CLS = 101  # [CLS] id prepended to every row
    TOKEN_SEP = 102  # [SEP] id appended to every row
    TOKEN_PAD = 0    # id used to right-pad rows up to SEQ_MAX
    SEQ_MAX = 258    # padded length, counting [CLS] and [SEP]

    def __init__(self, paths_file, len_dataset=-1):
        self.paths_file_all = paths_file
        self.seq_max = SimpleMLMDataset.SEQ_MAX

        self.index_file = -1  # index of the most recently loaded file; -1 = none loaded yet

        self.file_current = None  # not read by this class
        self.rows_current = []    # rows of the current file not returned yet
        self.len_dataset = len_dataset
    # end


    def __getitem__(self, idx):
        """Return the next sample from the cursor; `idx` is ignored (see class docstring).

        Files that cannot be opened/decoded, or whose JSON is not a list, are skipped;
        empty rows are skipped.

        Raises:
            ValueError: if the dataset was built with no file paths.
            RuntimeError: if a full pass over all files yields no usable row
                (instead of looping forever).
        """
        num_files = len(self.paths_file_all)
        if num_files == 0:
            raise ValueError('SimpleMLMDataset: no data files were given')
        # end

        loads_without_row = 0  # files loaded during this call that produced no row yet
        error_last = None
        row_current = None

        while not row_current:
            if self.rows_current:
                row_current = self.rows_current.pop()  # an empty row is falsy -> keep looping
                continue
            # end

            if loads_without_row >= num_files:
                raise RuntimeError(
                    'SimpleMLMDataset: no usable row found in any of the {} data files'.format(num_files)
                ) from error_last
            # end
            loads_without_row += 1

            self.index_file = (self.index_file + 1) % num_files
            path_file_current = self.paths_file_all[self.index_file]
            try:
                with open(path_file_current, 'r') as file:
                    rows_loaded = json.load(file)
                # end
            except (OSError, ValueError) as ex:  # unreadable file or invalid JSON: skip it
                error_last = ex
                continue
            # end

            # Only a JSON list can be consumed with pop(); any other value is treated as an empty file.
            self.rows_current = rows_loaded if isinstance(rows_loaded, list) else []
        # end

        # e.g. [2335, 15464, 2361, 103, ...] -> SEQ_MAX (258) ids. The original author noted
        # "should be 256 but has bug"; the padded length here includes [CLS] and [SEP].
        sample = self.transfer_one_row_to_sample(row_current, self.seq_max)

        return sample
    # end

    def __len__(self):
        """Return the configured `len_dataset` (not computed from the files)."""
        return self.len_dataset
    # end

    def transfer_one_row_to_sample(self, tokens_raw, seq_max):
        """Wrap one row in [CLS] ... [SEP] and right-pad it with TOKEN_PAD to `seq_max`.

        Rows longer than ``seq_max - 2`` are not truncated; the sample is then longer than
        `seq_max` and the collator pads the rest of the batch up to it.

        Returns:
            dict with
              'input_ids': [CLS] + row + [SEP] + padding,
              'special_tokens_mask': 1 for [CLS]/[SEP]/padding (never picked for masking),
                  0 for the row's tokens,
              'attention_mask': 1 for [CLS]/row/[SEP], 0 for the padding added here.
                  Without this key, tokenizer.pad() in the collator would build an
                  all-ones mask that also covers this padding.
        """
        tokens_0 = [self.TOKEN_CLS] + tokens_raw + [self.TOKEN_SEP]
        num_real = len(tokens_0)
        num_padding = max(seq_max - num_real, 0)

        sample_0 = {
            'input_ids': tokens_0 + [self.TOKEN_PAD] * num_padding,
            'special_tokens_mask': [1] + [0] * (num_real - 2) + [1] + [1] * num_padding,
            'attention_mask': [1] * num_real + [0] * num_padding,
        }
        return sample_0
    # end


# end

In [7]:
def _list_data_files(path_folder):
    """Return the names of the non-hidden entries of `path_folder`, sorted.

    os.listdir returns entries in arbitrary order. SimpleMLMDataset reads its files
    sequentially, so sorting keeps the sample order identical across machines and runs.

    Raises:
        FileNotFoundError: if the folder has no non-hidden entries, because the
            dataset cannot read from an empty file list.
    """
    names = sorted(name for name in os.listdir(path_folder) if not name.startswith('.'))
    if not names:
        raise FileNotFoundError('no data files found in {!r}'.format(path_folder))
    return names

# path_folder_data = 'data'
path_folder_data = 'full_debug_encoded_all_half_01'
num_data_train = 4634404  # reported by __len__; presumably the total row count of the training files -- TODO confirm

filenames_data_train = _list_data_files(path_folder_data)
paths_file_train = [os.path.join(path_folder_data, name) for name in filenames_data_train]

# path_folder_data_eval = 'data_eval_2023'
# num_data_eval = 209

path_folder_data_eval = 'data_eval'
num_data_eval = 813

filenames_data_eval = _list_data_files(path_folder_data_eval)
paths_file_eval = [os.path.join(path_folder_data_eval, name) for name in filenames_data_eval]

dataset_train = SimpleMLMDataset(paths_file_train, num_data_train)
dataset_eval = SimpleMLMDataset(paths_file_eval, num_data_eval)  # 2022: 604, 2023:209, data_eval: 813

In [8]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before the Trainer accumulates them.

    Args:
        logits: logits tensor, or a tuple whose first element is the logits (some model
            configs also return extra tensors such as past_key_values).
        labels: unused; part of the Trainer callback signature.

    Returns:
        argmax over the last (vocabulary) axis.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
# end

metric = evaluate.load("accuracy")

# compute_metrics for the trainer
# check np_test for detail
def compute_metrics(eval_preds):
    """Masked-token accuracy for the Trainer.

    Args:
        eval_preds: (predictions, label_ids) pair from the Trainer. `predictions` may be
            either argmax token ids with the same shape as the labels (when the Trainer is
            given preprocess_logits_for_metrics) or raw logits with one extra trailing
            vocabulary axis; both are accepted. If predictions is a tuple, its first
            element is used. Labels equal to -100 mark positions that were not masked.

    Returns:
        The result of the `evaluate` accuracy metric (a dict with an 'accuracy' entry).
    """
    preds, labels = eval_preds

    if isinstance(preds, tuple):
        preds = preds[0]
    # end

    # Raw logits (one axis more than the labels) still need the argmax over the vocabulary;
    # without it, reshape(-1) yields vocab-times more elements than the label mask.
    if preds.ndim == labels.ndim + 1:
        preds = preds.argmax(axis=-1)
    # end

    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    mask = labels != -100  # score only the positions the collator selected for MLM
    return metric.compute(predictions=preds[mask], references=labels[mask])
# end

# original compute_metrics
# def compute_metrics(eval_preds):
#     logits, labels = eval_preds
    
#     if isinstance(logits, tuple):
#         logits = logits[0]
#     # end
    
#     # preds = logits.argmax(dim=-1)
#     preds = logits.argmax(axis=-1)
#     # preds have the same shape as the labels, after the argmax(-1) has been calculated
#     # by preprocess_logits_for_metrics
#     labels = labels.reshape(-1)
#     preds = preds.reshape(-1)
#     mask = labels != -100
#     labels = labels[mask]
#     preds = preds[mask]
#     return metric.compute(predictions=preds, references=labels)

In [9]:
training_args = TrainingArguments(
    output_dir='./outputs',  # checkpoints are written here
    num_train_epochs=4,  # total number of training epochs
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,  # batch size for evaluation
    warmup_steps=0,  # no learning-rate warmup
    weight_decay=0.01,  # strength of weight decay
    logging_dir='./logs',  # TensorBoard log directory
    # Reload the best checkpoint when training finishes; "best" is judged by
    # metric_for_best_model below (requires save_strategy == evaluation_strategy).
    load_best_model_at_end=True,
    logging_steps=10,  # log the training loss every 10 optimizer steps (saving is per epoch, see save_strategy)
    evaluation_strategy="epoch",  # evaluate once at the end of every epoch
    learning_rate=2e-5,
    save_strategy='epoch',  # save a checkpoint once per epoch
    save_total_limit=1,  # with load_best_model_at_end, the best checkpoint is kept as well as the latest
    metric_for_best_model='accuracy',  # key returned by compute_metrics; higher is better
    eval_accumulation_steps=16,  # move accumulated eval predictions off the device every 16 steps
    # prediction_loss_only=True
    # predict_with_generate=True,
    # auto_find_batch_size=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    # Turn logits into argmax ids on-device before accumulation. compute_metrics expects
    # preds shaped like the labels, and the full (batch, seq, vocab) logits of the whole
    # eval set would otherwise be gathered in memory.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)


In [10]:
# Start training from scratch (pass a checkpoint path to resume); the saved run was stopped by KeyboardInterrupt.
trainer.train(resume_from_checkpoint=None)

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Marker that the training cell above has returned.
print('done', end='\n')

In [None]:
# from transformers.trainer_pt_utils import nested_concat

In [None]:
# nested_concat.__code__