In [1]:
import logging
import sys
from functools import partial
from pathlib import Path

from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
import pandas as pd

# --- Data -------------------------------------------------------------------
# Hub dataset id and the value passed as the second argument to load_dataset.
dataset_name = "kamel-usp/aes_enem_dataset"
dataset_split = "JBCS2025"
# Index of the grade used as the label when tokenizing.
# NOTE(review): presumably index 0 pairs with the "-C1" checkpoint below — confirm.
grade_index = 0

# --- Models -----------------------------------------------------------------
# model_type is forwarded to the project helpers and to the metrics config.
model_type = "encoder_classification"
# Base checkpoint used to load the tokenizer.
model_name = "google-bert/bert-base-multilingual-cased"
# Fine-tuned checkpoint that is evaluated in this notebook.
fine_tuned_model_id = "kamel-usp/jbcs2025_mbert_base-C1"

# --- Local storage ----------------------------------------------------------
# Download/cache location shared by the dataset, tokenizer and model loads.
cache_dir = "/tmp/"

In [2]:
# Make the sibling `scripts/` directory (next to this notebook's parent folder)
# importable; the project-local imports in the following cell rely on it.
# The path is joined with pathlib so separators are correct on every platform.
scripts_dir = Path(".").resolve().parent / "scripts"

# Kept as a plain string under its original name for any later cell that reads it.
parent_dir = str(scripts_dir)

# Guard against duplicates so re-running this cell leaves sys.path unchanged.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [3]:
from preprocess import load_tokenizer, tokenize_dataset
from metrics.metrics import compute_metrics

In [4]:
# Load the dataset. Note that `dataset_split` is passed as load_dataset's
# second argument, which the `datasets` library treats as the configuration name.
dataset = load_dataset(path=dataset_name, name=dataset_split, cache_dir=cache_dir)

# Tokenizer comes from the base checkpoint (`model_name`), via the project helper.
tokenizer = load_tokenizer(model_type, model_name, cache_dir=cache_dir)

# Tokenize the "essay_text" column with the project helper; the root logger is
# handed over for its log messages.
tokenized_dataset = tokenize_dataset(
    dataset,
    tokenizer,
    model_type=model_type,
    grade_index=grade_index,
    text_column="essay_text",
    logger=logging.getLogger(),
)

In [5]:
# Load the fine-tuned sequence-classification checkpoint, using cache_dir for
# the downloaded files.
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=fine_tuned_model_id,
    cache_dir=cache_dir,
)

config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/711M [00:00<?, ?B/s]

In [6]:
# Minimal stand-in for the project's experiment config: only
# `experiments.model.type` is populated. It is handed to compute_metrics
# as `cfg` in a later cell.
experiment_config = OmegaConf.create({"experiments": {"model": {"type": model_type}}})

In [7]:
# Evaluation-only arguments; Trainer output goes to ./test_trainer.
training_args = TrainingArguments(
    output_dir="test_trainer", do_eval=True, per_device_eval_batch_size=16
)

# Pre-bind the config so the Trainer can invoke the metric function without
# having to supply `cfg` itself.
compute_metrics_partial = partial(compute_metrics, cfg=experiment_config)

# Split to evaluate on; swap the key to score a different split.
eval_split = tokenized_dataset["test"]

# No train_dataset is supplied: this Trainer is only used for evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics_partial,
    eval_dataset=eval_split,
)

In [8]:
# Run evaluation on the Trainer's eval_dataset, then display the returned
# metrics dict as a one-row table (one column per metric).
eval_results = trainer.evaluate()
pd.DataFrame.from_dict(eval_results, orient="index").transpose()

Unnamed: 0,eval_loss,eval_model_preparation_time,eval_accuracy,eval_RMSE,eval_QWK,eval_HDIV,eval_Macro F1,eval_runtime,eval_samples_per_second,eval_steps_per_second
0,0.954227,0.0058,0.536232,30.072376,0.450592,0.007246,0.324464,1.1658,118.378,7.72
