In [None]:
import transformers
import torch
from transformers import (
    BertModel, 
    BertTokenizer, 
    AutoModelForMaskedLM, 
    DataCollatorForLanguageModeling,
    Trainer
)
import datasets
from datasets import load_dataset, load_metric
import pandas as pd
import os
import numpy as np
from torch.utils.data import DataLoader
import tqdm

# Train LM: continue masked-language-model training of ProtBert-BFD

In [None]:
# --- Configuration -----------------------------------------------------------
model_name = "Rostlab/prot_bert_bfd"  # Hugging Face Hub checkpoint to continue MLM training from
dataset_name = "sequences"  # path of the raw text data read by load_dataset("text", ...)
cache_dir = "./cache"  # cache directory passed to every load_dataset call
validation_split_percentage = 5  # % of "train" held out when the data has no validation split

# Every `load_dataset("text", data_files=data_files, ...)` call below reads this mapping;
# it sends the default "train" split to `dataset_name` (later sliced as "train[...%]").
# NOTE(review): assumes `dataset_name` is a path/glob to plain-text file(s) with one
# example per line, resolved relative to the working directory — TODO confirm.
data_files = {"train": dataset_name}

In [None]:
# Pretrained checkpoint including its masked-LM head, plus the matching tokenizer.
model = AutoModelForMaskedLM.from_pretrained(model_name)
# The ProtBert model card loads this tokenizer with do_lower_case=False: its vocabulary
# uses upper-case one-letter residue codes, so lower-casing would turn residues into [UNK].
# Passing it explicitly keeps that true even if the hub tokenizer config does not set it.
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)

In [None]:
# Read the raw text file(s): each line becomes one example in a single "text" column.
raw_datasets = load_dataset(
    "text",
    cache_dir=cache_dir,
    data_files=data_files,
)

In [None]:
# When the data ships without a validation split, hold out the first
# `validation_split_percentage` percent of "train" and keep the rest for training.
# Both splits are re-read from the same files with datasets' percent-slicing syntax.
if "validation" not in raw_datasets:
    split_specs = {
        "validation": f"train[:{validation_split_percentage}%]",
        "train": f"train[{validation_split_percentage}%:]",
    }
    for split_name, split_spec in split_specs.items():
        raw_datasets[split_name] = load_dataset(
            "text",
            data_files=data_files,
            split=split_spec,
            cache_dir=cache_dir,
        )

In [None]:
# Batch builder for MLM: masks 15% of tokens and pads each batch to a multiple of 8.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    mlm_probability=0.15,
)

In [None]:
def compute_metrics(eval_preds, metric=None):
    """Score an evaluation run by token accuracy on the labelled (masked) positions.

    Args:
        eval_preds: ``(preds, labels)`` pair as handed over by the Trainer. ``preds``
            are already predicted token ids (``preprocess_logits_for_metrics``
            applies ``argmax(-1)``), so they have the same shape as ``labels``.
        metric: optional object exposing ``compute(predictions=..., references=...)``
            (e.g. ``load_metric("accuracy")``). When ``None`` (the default), accuracy
            is computed directly with NumPy, so no global ``metric`` is required.

    Returns:
        ``metric.compute(...)`` when ``metric`` is given; otherwise
        ``{"accuracy": float}``, which is NaN when no position carries a label.
    """
    preds, labels = eval_preds
    labels = np.asarray(labels).reshape(-1)
    preds = np.asarray(preds).reshape(-1)
    # -100 is the ignore index: positions without an MLM target (for the MLM
    # collator, tokens it did not mask) must not count toward the score.
    scored = labels != -100
    labels = labels[scored]
    preds = preds[scored]
    if metric is not None:
        return metric.compute(predictions=preds, references=labels)
    if labels.size == 0:
        # Nothing was scored; avoid NumPy's empty-mean RuntimeWarning.
        return {"accuracy": float("nan")}
    return {"accuracy": float((preds == labels).mean())}

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before they are gathered for metrics.

    Args:
        logits: logits tensor, or a tuple whose first element is the logits (any
            further entries, e.g. cached states, are ignored).
        labels: unused; accepted because the Trainer calls this hook with it.

    Returns:
        ``logits.argmax(dim=-1)``, the id of the highest-scoring token per position.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)

In [None]:
def tokenize_function(examples, max_length=128):
    """Tokenize a batch of raw text lines for masked-language-model training.

    Meant for ``Dataset.map(..., batched=True)``. Empty and whitespace-only lines
    are dropped first, so the returned batch can be shorter than the input; this
    relies on the map call removing the ``text`` column.

    Args:
        examples: batch mapping with a ``"text"`` list of strings (left unmodified).
        max_length: every sequence is padded/truncated to this many tokens,
            special tokens included. Defaults to 128.

    Returns:
        The tokenizer output for the kept lines, including ``special_tokens_mask``.
    """
    # NOTE(review): the ProtBert model card feeds residues separated by spaces
    # ("M K T ...") and maps rare residues U/Z/O/B to X. Confirm the input file is
    # already in that format; an unspaced sequence would tokenize as one [UNK].
    lines = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]

    return tokenizer(
        lines,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        # We use this option because DataCollatorForLanguageModeling is more
        # efficient when it receives the `special_tokens_mask`.
        return_special_tokens_mask=True,
    )

In [None]:
# Tokenize every split in batches. The raw "text" column is dropped, which also lets
# tokenize_function return fewer rows than it received (blank lines are filtered out).
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    load_from_cache_file=True,
    desc="Running tokenizer on dataset line_by_line",
)

In [None]:
# The two tokenized splits handed to the Trainer.
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["validation"]

In [None]:
# Stand-alone iterator over collated training batches of 8.
# NOTE(review): nothing below uses it (the Trainer builds its own loader);
# presumably kept for inspecting batches by hand — consider removing.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    collate_fn=data_collator,
)

In [None]:
# Wire model, data and metric hooks into a Trainer.
# NOTE(review): no `args` is passed, so the library's default TrainingArguments apply
# (output dir, batch size, epochs, evaluation schedule). With the defaults, the metric
# hooks only run on an explicit evaluation such as trainer.evaluate() — confirm intent.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training and show the returned training summary as the cell output.
train_output = trainer.train()
train_output