In [1]:
import numpy as np
from datasets import load_metric
from datasets import load_from_disk
from transformers import DataCollatorForLanguageModeling
from transformers import DistilBertForMaskedLM, DistilBertTokenizer

def compute_metrics(pred):
    labels = np.array(pred.label_ids)
    preds = np.array(pred.predictions.argmax(-1))
    metric = load_metric('accuracy')
    
    acc = metric.compute(predictions=preds[labels > -100], references=labels[labels > -100])['accuracy']

    return {'accuracy': acc}

# Pretrained checkpoint supplying both the vocabulary and the initial weights.
# Alternative used earlier for the domain-tuned weights:
#   DistilBertForMaskedLM.from_pretrained("BERT_uniprot_mlm/checkpoint-28000/", local_files_only=True)
BASE_CHECKPOINT = "distilbert-base-uncased"

tokenizer = DistilBertTokenizer.from_pretrained(BASE_CHECKPOINT)
model = DistilBertForMaskedLM.from_pretrained(BASE_CHECKPOINT)

# Pre-tokenized UniProt splits saved with `save_to_disk`; the cells below
# read the 'train' and 'test' splits from it.
dataset = load_from_disk("prepared_uniprot")

# Builds masked-LM batches: a random 15% of tokens are selected for masking
# in every batch, and the labels for all other positions are set to -100.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [2]:
# Training configuration for fine-tuning `model` with the masked-LM collator.
# Disabled: its checkpoints are written under ./BERT_uniprot_mlm (the setup
# cell's commented-out load points at BERT_uniprot_mlm/checkpoint-28000/).
# NOTE(review): `prediction_loss_only=True` makes the Trainer skip collecting
# predictions, so `compute_metrics` is never invoked. That is consistent with
# the recorded evaluation output below, which has `eval_loss` but no
# `eval_accuracy`.
# NOTE(review): `per_gpu_train_batch_size` is deprecated; re-enabling this cell
# should use `per_device_train_batch_size` instead.
# from transformers import Trainer, TrainingArguments

# training_args = TrainingArguments(
#     output_dir="./BERT_uniprot_mlm",
#     overwrite_output_dir=True,
#     num_train_epochs=5,
#     per_gpu_train_batch_size=32,
#     dataloader_num_workers=16,
#     save_steps=1000,
#     save_total_limit=2,
#     prediction_loss_only=True,
#     warmup_steps=1000,
#     weight_decay=0.01,
#     fp16=True, 
#     learning_rate=0.00005,
#     logging_strategy="steps",
#     logging_steps=100
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=data_collator,
#     train_dataset=dataset['train'],
#     eval_dataset=dataset['test'],
#     compute_metrics=compute_metrics
# )

In [3]:
# trainer.train()

In [5]:
# trainer.evaluate()

{'eval_loss': 0.46399587392807007,
 'eval_runtime': 387.3582,
 'eval_samples_per_second': 239.775,
 'eval_steps_per_second': 14.986,
 'epoch': 5.0}

In [4]:
# import torch

# torch.save(model.state_dict(),"BERT_uniprot_mlm/tuned_model.pt")

In [2]:
# Baseline: evaluate `model` as loaded in the setup cell (the pretrained
# distilbert-base-uncased weights, without domain fine-tuning).
from transformers import Trainer, TrainingArguments

# Only a slice of the test split is scored. Collecting predictions for
# `compute_metrics` gathers full-vocabulary logits, which is slow: about 3.4
# samples/s here, versus about 240 samples/s for the loss-only evaluation.
# TODO(review): consider `preprocess_logits_for_metrics` to argmax on device.
N_EVAL_SAMPLES = 1000

training_args = TrainingArguments(
    output_dir="./BERT_uniprot_mlm",
    # `per_gpu_train_batch_size` is deprecated. `per_device_train_batch_size`
    # gives the same effective batch size (per-device value x number of GPUs).
    per_device_train_batch_size=256,
    dataloader_num_workers=32,
    fp16=True,
    # Move accumulated prediction tensors to the CPU every 2 eval steps.
    eval_accumulation_steps=2,
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset["test"].select(range(N_EVAL_SAMPLES)),
    compute_metrics=compute_metrics,
)

trainer.evaluate()

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.


  metric = load_metric('accuracy')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


{'eval_loss': 10.18470287322998,
 'eval_accuracy': 0.061654135338345864,
 'eval_runtime': 295.0402,
 'eval_samples_per_second': 3.389,
 'eval_steps_per_second': 0.214}

In [13]:
# Fine-tuned: evaluate the domain-tuned checkpoint on the same test slice.
# Load it explicitly rather than reusing `model`. The setup cell initializes
# `model` from the base distilbert-base-uncased weights, so on a fresh
# "Restart & Run All" this cell would otherwise silently re-evaluate the base
# model under a "fine-tuned" label.
from transformers import Trainer, TrainingArguments

FINETUNED_CHECKPOINT = "BERT_uniprot_mlm/checkpoint-28000/"
# Keep the subset identical to the baseline cell so the numbers are comparable.
N_EVAL_SAMPLES = 1000

finetuned_model = DistilBertForMaskedLM.from_pretrained(
    FINETUNED_CHECKPOINT, local_files_only=True
)

training_args = TrainingArguments(
    output_dir="./BERT_uniprot_mlm",
    # `per_gpu_train_batch_size` is deprecated. `per_device_train_batch_size`
    # gives the same effective batch size (per-device value x number of GPUs).
    per_device_train_batch_size=256,
    dataloader_num_workers=32,
    fp16=True,
    # Move accumulated prediction tensors to the CPU every 2 eval steps.
    eval_accumulation_steps=2,
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=finetuned_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset["test"].select(range(N_EVAL_SAMPLES)),
    compute_metrics=compute_metrics,
)

trainer.evaluate()

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.


  metric = load_metric('accuracy')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


{'eval_loss': 0.4505701959133148,
 'eval_accuracy': 0.9045112781954887,
 'eval_runtime': 303.8642,
 'eval_samples_per_second': 3.291,
 'eval_steps_per_second': 0.207}