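# Fine-tune base model

Continue masked-language-model (MLM) training of an XLM-RoBERTa checkpoint on the `train` split of the processed `train-test` dataset, then save the resulting weights and tokenizer.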

1. Import dependencies

In [1]:
import random

import numpy as np
import torch

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    AutoModelForMaskedLM,
    TrainingArguments,
    Trainer,
)
from datasets import DatasetDict

from src.util.torch_device import resolve_torch_device
from src.definitions import (
    MODELS_FOLDER,
    PROCESSED_DATA_FOLDER
)

2. Prepare environment (seeds, device, paths)

In [2]:
# Seed the Python, NumPy and PyTorch global RNGs so repeated runs are comparable.
random_seed = 42

random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = resolve_torch_device()

# Starting checkpoint. Resolve it against MODELS_FOLDER (the same root this
# notebook writes its outputs to) rather than a cwd-relative "models/..." string,
# so loading does not depend on the kernel's working directory.
# str(...) keeps the value the same type that from_pretrained received before.
model_checkpoint = str(
    MODELS_FOLDER
    / "train-test-fine-tuned-models-ru-fine-tuned-FacebookAI-xlm-roberta-base-3-checkpoint"
    / "checkpoint-500"
)
# Name of this run: used for both the intermediate checkpoint directory and the
# final save directory under MODELS_FOLDER.
fine_tune_name = "train-test-fine-tuned-models-ru-fine-tuned-FacebookAI-xlm-roberta-base-4"

device

device(type='cuda')

3. Load tokenizer and dataset

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Only borrow EOS as the pad token when the tokenizer has none (the GPT-style
# case). XLM-RoBERTa tokenizers ship a dedicated "<pad>" token, which RoBERTa-style
# models also use as their padding index. Overwriting it with "</s>" changes which
# id the MLM collator excludes from masking, and the override would be persisted
# by tokenizer.save_pretrained at the end of this notebook.
# NOTE(review): if the tokenizer saved in model_checkpoint already had its pad
# token overridden by an earlier run, this guard will not restore "<pad>" — verify.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataset = DatasetDict.load_from_disk(PROCESSED_DATA_FOLDER / "train-test")

3b. Prepare model and trainer

In [4]:
# Load the masked-LM model from the starting checkpoint onto the selected device.
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

training_args = TrainingArguments(
    output_dir=MODELS_FOLDER / f"{fine_tune_name}-checkpoint",
    save_strategy="steps",  # saves every save_steps (default 500)
    per_device_train_batch_size=12,
    # NOTE(review): very small LR; presumably deliberate for continued training — confirm.
    learning_rate=2e-7,
    num_train_epochs=1,
    weight_decay=0.01,
    save_total_limit=3,  # keep at most 3 intermediate checkpoints on disk
    bf16=True,  # requires bf16-capable hardware
    # Trainer calls set_seed(args.seed) when constructed, re-seeding the RNGs
    # seeded in the config cell; without this it would silently use its own
    # default (42) no matter what random_seed is set to.
    seed=random_seed,
)

# Pads each batch and randomly masks 15% of tokens for the MLM objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Only the "train" split is used; no eval_dataset is passed, so no evaluation
# runs during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    data_collator=data_collator,
)

4. Train model

In [5]:
# Release cached-but-unused CUDA memory held by the allocator before the long
# training run starts.
torch.cuda.empty_cache()

# Run training and display the TrainOutput summary (global step, loss, runtime).
train_output = trainer.train()
train_output

Step,Training Loss
500,0.5362


TrainOutput(global_step=797, training_loss=0.6716426145776153, metrics={'train_runtime': 2365.6163, 'train_samples_per_second': 4.04, 'train_steps_per_second': 0.337, 'total_flos': 1982130101280540.0, 'train_loss': 0.6716426145776153, 'epoch': 1.0})

5. Save weights and tokenizer

In [6]:
# Final artifacts go to MODELS_FOLDER/<fine_tune_name>, alongside the
# "<fine_tune_name>-checkpoint" directory used for intermediate checkpoints.
final_model_dir = MODELS_FOLDER / fine_tune_name
trainer.save_model(final_model_dir)
# NOTE(review): recent Trainer versions also save processing_class inside
# save_model; this explicit call is kept, and its returned file list is displayed.
tokenizer.save_pretrained(final_model_dir)

('/home/melal/Workspace/unlp-2025-manipulation-detector/models/train-test-fine-tuned-models-ru-fine-tuned-FacebookAI-xlm-roberta-base-4/tokenizer_config.json',
 '/home/melal/Workspace/unlp-2025-manipulation-detector/models/train-test-fine-tuned-models-ru-fine-tuned-FacebookAI-xlm-roberta-base-4/special_tokens_map.json',
 '/home/melal/Workspace/unlp-2025-manipulation-detector/models/train-test-fine-tuned-models-ru-fine-tuned-FacebookAI-xlm-roberta-base-4/tokenizer.json')