In [7]:
import transformers
import torch
from transformers import (
    BertModel, 
    BertTokenizer, 
    AutoModelForMaskedLM, 
    DataCollatorForLanguageModeling,
    Trainer
)
import datasets
from datasets import load_dataset, load_metric
import pandas as pd
import os
import numpy as np
from torch.utils.data import DataLoader
import tqdm

# Train the masked language model (MLM)

In [None]:
# Run configuration shared by the cells below.
# NOTE(review): later cells pass `data_files` to `load_dataset`, but no cell visible in this
# notebook defines it, and `dataset_name` is not referenced by any visible cell — confirm.
validation_split_percentage = 5  # percent of "train" carved off when no validation split exists
cache_dir = "./cache"  # handed to `load_dataset` as its cache directory
dataset_name = "sequences"
model_name = "Rostlab/prot_bert_bfd"  # checkpoint id handed to `from_pretrained`

In [None]:
# Load the pretrained masked-LM checkpoint and the tokenizer that belongs to it.
model = AutoModelForMaskedLM.from_pretrained(model_name)
# The ProtBert model card loads its tokenizer with `do_lower_case=False` (the vocabulary is
# upper-case amino-acid letters); pass it explicitly rather than relying on the hub config
# so sequences are never lower-cased before vocabulary lookup.
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)

In [101]:
# Load the raw data through the generic "text" dataset builder.
# NOTE(review): `data_files` must already exist in the kernel; it is not defined in any
# visible cell — confirm where it comes from.
raw_datasets = load_dataset("text", cache_dir=cache_dir, data_files=data_files)

Using custom data configuration default-0ef9a2c30a4c85ff
Reusing dataset text (./cache/text/default-0ef9a2c30a4c85ff/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)


  0%|          | 0/1 [00:00<?, ?it/s]

In [102]:
# When the loaded data has no "validation" split, carve one out of "train":
# the first `validation_split_percentage` percent becomes validation and the
# remainder is kept as the training split.
if "validation" not in raw_datasets.keys():
    for split_name, split_spec in (
        ("validation", f"train[:{validation_split_percentage}%]"),
        ("train", f"train[{validation_split_percentage}%:]"),
    ):
        raw_datasets[split_name] = load_dataset(
            "text",
            data_files=data_files,
            split=split_spec,
            cache_dir=cache_dir,
        )

Using custom data configuration default-0ef9a2c30a4c85ff
Reusing dataset text (./cache/text/default-0ef9a2c30a4c85ff/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)
Using custom data configuration default-0ef9a2c30a4c85ff
Reusing dataset text (./cache/text/default-0ef9a2c30a4c85ff/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)


In [103]:
# Collator that applies masked-language-modeling masking when batches are assembled:
# `mlm_probability` is the masking probability handed to the collator and
# `pad_to_multiple_of` asks it to pad batch sequence lengths up to a multiple of 8.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    mlm_probability=0.15,
)

In [104]:
def compute_metrics(eval_preds):
    """Compute token-level accuracy over the scored (non-ignored) positions.

    Args:
        eval_preds: A ``(predictions, labels)`` pair. Predictions are expected to
            already be token ids with the same shape as the labels, i.e. after
            ``preprocess_logits_for_metrics`` has applied ``argmax(-1)``.

    Returns:
        dict: ``{"accuracy": float}`` — the fraction of scored positions whose
        predicted id equals the label, or ``nan`` when no position is scored.
    """
    preds, labels = eval_preds
    labels = np.asarray(labels).reshape(-1)
    preds = np.asarray(preds).reshape(-1)
    # Labels equal to -100 are dropped: only the remaining positions are scored.
    mask = labels != -100
    labels = labels[mask]
    preds = preds[mask]
    # Accuracy is computed directly with NumPy so this function does not depend on a
    # module-level `metric` object (none is defined in the visible notebook cells).
    # NOTE(review): confirm plain accuracy is the intended evaluation metric.
    if labels.size == 0:
        # Avoid a mean over an empty array (RuntimeWarning) when nothing is scored.
        return {"accuracy": float("nan")}
    return {"accuracy": float((preds == labels).mean())}

In [105]:
def preprocess_logits_for_metrics(logits, labels):
    """Collapse model logits to predicted ids along the last dimension.

    Args:
        logits: Either the logits tensor itself or a tuple whose first element is
            the logits tensor (any further elements are ignored).
        labels: Unused; accepted to match the callback signature.

    Returns:
        The ``argmax`` of the logits over the last dimension.
    """
    # When the model output is a tuple, the logits are its first element.
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)

In [106]:
def tokenize_function(examples):
    """Tokenize a batch of text lines after discarding blank ones.

    Args:
        examples: Batch mapping with a ``"text"`` list of strings. The list is
            replaced in place by its non-empty, non-whitespace entries.

    Returns:
        The output of ``tokenizer`` for the remaining lines, padded and truncated
        to 128 tokens and including the special-tokens mask.
    """
    kept_lines = []
    for entry in examples["text"]:
        # Skip lines that are empty or consist solely of whitespace.
        if len(entry) > 0 and not entry.isspace():
            kept_lines.append(entry)
    examples["text"] = kept_lines

    # The special-tokens mask is requested because DataCollatorForLanguageModeling
    # is more efficient when it receives it.
    # NOTE(review): the expected per-line sequence format is not visible here — confirm
    # it matches what this checkpoint's tokenizer expects.
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
        "return_special_tokens_mask": True,
    }
    return tokenizer(examples["text"], **tokenizer_options)

In [107]:
# Tokenize every split in batches; the raw "text" column is removed from the result,
# and previously cached results are reused when available.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    remove_columns=["text"],
    batched=True,
    load_from_cache_file=True,
    desc="Running tokenizer on dataset line_by_line",
)

Loading cached processed dataset at ./cache/text/default-0ef9a2c30a4c85ff/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8/cache-7f96015c4986f7cd.arrow
Loading cached processed dataset at ./cache/text/default-0ef9a2c30a4c85ff/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8/cache-f5081781e74c7a2b.arrow


In [108]:
# Pull out the two tokenized splits that are handed to the Trainer.
train_dataset, eval_dataset = (
    tokenized_datasets["train"],
    tokenized_datasets["validation"],
)

In [133]:
# Stand-alone PyTorch loader over the training split, batching with the MLM collator.
# NOTE(review): the Trainer cell does not reference this loader — confirm it is still needed.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    collate_fn=data_collator,
)

In [109]:
# Wire model, data, collator and metric callbacks into a Trainer.
# No training arguments are supplied, so the Trainer falls back to its defaults
# (the cell's log reports `output_dir=tmp_trainer`).
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

No `TrainingArguments` passed, using `output_dir=tmp_trainer`.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [110]:
# Launch training with the Trainer built in the previous cell.
# The recorded output of this cell ends in a KeyboardInterrupt, i.e. the saved run
# was stopped manually before it completed.
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: special_tokens_mask. If special_tokens_mask are not expected by `BertForMaskedLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 166962
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 62613


Step,Training Loss


KeyboardInterrupt: 