<a href="https://colab.research.google.com/github/futugyou/pyproject/blob/master/google_colab/generation_representation_model_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required dependencies
%pip install datasets
%pip install sentence_transformers
%pip install transformers
%pip install torch
%pip install tqdm
%pip install scikit-learn

In [None]:
from datasets import load_dataset

tomatoes = load_dataset("rotten_tomatoes")
# Use the dataset's predefined splits: "train" for fine-tuning, "test" for evaluation.
train_dataset = tomatoes["train"]
test_dataset = tomatoes["test"]

In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "bert-base-cased"
# Tokenizer and masked-LM model come from the same checkpoint so their vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of examples' "text" field, truncating over-long inputs.

    Args:
        examples: a batched slice of the dataset (as passed by ``Dataset.map(batched=True)``).

    Returns:
        The tokenizer's encoding for every text in the batch.
    """
    return tokenizer(examples["text"], truncation=True)


# Tokenize the train split, then the test split, and drop the sentiment "label"
# column from each: it is not a masked-LM target (the MLM data collator builds
# its own training targets from the input tokens).
tokenized_train, tokenized_test = (
    split.map(preprocess_function, batched=True).remove_columns("label")
    for split in (train_dataset, test_dataset)
)

In [None]:
from transformers import DataCollatorForLanguageModeling

# Batches are built by the MLM collator: with mlm=True it randomly selects
# tokens (probability 0.15 each) as masked-language-modelling targets.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

In [None]:
from transformers import TrainingArguments

# Hyper-parameters for masked-LM fine-tuning; checkpoints go under "model/".
training_args = TrainingArguments(
    "model",  # output_dir
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    # The Trainer is given an eval_dataset; the default eval strategy is "no",
    # so without this the test split (and per_device_eval_batch_size) is never used.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Saving every epoch for 10 epochs would otherwise keep 10 full checkpoints
    # (weights + optimizer state) on disk; retain only the two most recent.
    save_total_limit=2,
    report_to="none",
)

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    # `processing_class` replaces the deprecated (and in current releases removed)
    # `tokenizer` argument of Trainer.
    processing_class=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# Write model and tokenizer together only after training finishes, so the "mlm"
# directory is never left holding a tokenizer without fine-tuned weights.
model.save_pretrained("mlm")
tokenizer.save_pretrained("mlm")

In [None]:
from transformers import pipeline

mask_filler = pipeline("fill-mask", model="mlm")

# Build the prompt from the tokenizer's own mask token rather than hard-coding
# BERT's "[MASK]", so it still works if model_id is switched to another MLM
# (e.g. one whose mask token is "<mask>").
prompt = f"what a horrible {mask_filler.tokenizer.mask_token}!"
preds = mask_filler(prompt)

for pred in preds:
    print(f">>> {pred['sequence']}")