# Fine-tune base model

1. Import dependencies

In [None]:
import random

import numpy as np
import torch
import os

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    AutoModelForMaskedLM,
    TrainingArguments,
    Trainer,
)
from datasets import DatasetDict

from src.model.mlm_metrics import compute_metrics
from src.util.torch_device import resolve_torch_device
from src.definitions import (
    MODELS_FOLDER,
    PROCESSED_DATA_FOLDER
)

2. Prepare environment

In [None]:
# Seed every RNG source used by this notebook so runs are reproducible.
random_seed = 42
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(random_seed)

# Device selection is delegated to the project helper.
device = resolve_torch_device()

# Base checkpoint on the Hugging Face Hub, plus the derived name used for the
# output folders ("/" is not path-safe, so it becomes "-"), e.g.
# "ru-fine-tuned-FacebookAI-xlm-roberta-base".
model_checkpoint = "FacebookAI/xlm-roberta-base"
fine_tune_name = "ru-fine-tuned-" + "-".join(model_checkpoint.split("/"))

device

3. Load tokenizer and dataset

In [None]:
# Tokenizer that matches the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# XLM-RoBERTa already ships a dedicated `<pad>` token whose id matches the
# model config's pad_token_id (RoBERTa-style models also use it as the padding
# index when computing position ids). Unconditionally replacing it with the
# EOS token -- a recipe for causal-LM tokenizers that lack a pad token -- would
# make padded positions look like real `</s>` tokens and desynchronise the
# tokenizer from the model config. Only fall back to EOS if no pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Processed dataset splits previously written with `save_to_disk`.
# NOTE(review): the "train" split is fed straight to the MLM collator below
# without a tokenization step here -- presumably it is already tokenized;
# confirm against the preprocessing code.
dataset = DatasetDict.load_from_disk(PROCESSED_DATA_FOLDER / "ru-news")

4. Prepare model

In [None]:
# Pretrained base checkpoint with a masked-language-modelling head.
# Trainer normally places the model on `training_args.device` itself; the
# explicit `.to(device)` keeps `model` usable interactively before training.
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

# bf16 mixed precision needs hardware support: TrainingArguments raises on CUDA
# GPUs without bf16 (e.g. pre-Ampere cards), and the device here is resolved
# dynamically. Enable bf16 only when the CUDA device supports it and fall back
# to full precision otherwise instead of crashing.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    # TrainingArguments documents output_dir as a str.
    output_dir=str(MODELS_FOLDER / f"{fine_tune_name}-checkpoint"),
    save_strategy="steps",
    per_device_train_batch_size=12,
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=3,  # keep at most 3 checkpoints in output_dir
    bf16=use_bf16,
    # Tie Trainer's own seeding to the notebook-level seed rather than relying
    # on its default happening to equal `random_seed`.
    seed=random_seed,
)

# Randomly selects 15% of tokens per batch for the masked-LM objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# NOTE(review): `compute_metrics` is imported at the top of the notebook but no
# eval_dataset/compute_metrics is passed, so only the training loss is reported.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    data_collator=data_collator,
)

5. Train model

In [None]:
# Release cached-but-unused CUDA allocator blocks before the long training run.
# empty_cache() does nothing when CUDA was never initialised, so the guard only
# makes that explicit.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

train_output = trainer.train()
train_output

6. Save model weights and tokenizer

In [None]:
# Write the final weights and the tokenizer into the same folder so both
# artefacts live side by side under MODELS_FOLDER.
final_model_dir = MODELS_FOLDER / fine_tune_name
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)