In [None]:
from itertools import chain
from pathlib import Path

import evaluate
from datasets import concatenate_datasets
from transformers import IntervalStrategy

from utils.dataset_utils import get_politicalness_datasets, \
    get_politicalness_datasets_from_leaning_datasets_for_leave_one_out_benchmark, \
    leaning_with_center_label_mapping
from utils.model_utils import finetune_models

In [None]:
# Selects the label mapping used when transforming the samples: when True,
# `leaning_with_center_label_mapping` is passed to `transform_for_inference`;
# when False, `None` is passed instead.
TRAINING_POLITICAL_LEANING = False

# Last component of the result directory handed to `finetune_models`.
RESULT_SUBDIRECTORY_NAME = "politicalness"

# Sizes passed to `take_even_class_distribution_sample` for the training and
# evaluation samples of every dataset.
# NOTE(review): whether the size is per class or in total depends on that helper - confirm.
TRAIN_DATASET_SAMPLE_SIZE = 1_000
EVAL_DATASET_SAMPLE_SIZE = 100


def GET_DATASETS():
    """Build a new iterator over all datasets that take part in the benchmark.

    Each call invokes both dataset getters again and chains their results, so
    the returned iterator can be consumed independently of earlier calls.
    """
    return chain(
        get_politicalness_datasets(),
        get_politicalness_datasets_from_leaning_datasets_for_leave_one_out_benchmark(),
    )

# The label mapping is the same for every dataset, so resolve it once up front.
label_mapping = leaning_with_center_label_mapping if TRAINING_POLITICAL_LEANING else None

eval_datasets = []
train_datasets_separate = []
for dataset in GET_DATASETS():
    # Draw the evaluation sample first so that its rows can be excluded from
    # the pool the training sample is drawn from.
    eval_dataset = (
        dataset
        .take_even_class_distribution_sample(EVAL_DATASET_SAMPLE_SIZE)
        .transform_for_inference(label_mapping)
    )
    eval_datasets.append(eval_dataset.to_huggingface())

    # Keep only the source rows whose index label is not in the eval sample.
    # NOTE(review): this relies on the sampled and transformed dataset keeping
    # the index labels of the source dataframe - confirm in the dataset utils.
    in_eval_sample = dataset.dataframe.index.isin(eval_dataset.dataframe.index)
    dataset.dataframe = dataset.dataframe.loc[~in_eval_sample]

    # The training sample is drawn from the remaining rows only.
    train_dataset = (
        dataset
        .take_even_class_distribution_sample(TRAIN_DATASET_SAMPLE_SIZE)
        .transform_for_inference(label_mapping)
    )
    train_datasets_separate.append(train_dataset.to_huggingface())

# Leave-one-out training sets: for every dataset, concatenate the training
# samples of all datasets with a different name. The merged set carries the
# info of the dataset that was left out, and it sits at the same position as
# that dataset's entry in `train_datasets_separate`.
train_datasets = [
    concatenate_datasets(
        [
            candidate
            for candidate in train_datasets_separate
            if candidate.info.dataset_name != held_out.info.dataset_name
        ],
        info=held_out.info,
    )
    for held_out in train_datasets_separate
]

In [None]:
# Evaluation schedule and seeds forwarded to `finetune_models`.
EVAL_STRATEGY = IntervalStrategy.EPOCH
TRAINING_SEED = 37
DATA_SEED = 37

# Directory passed to `finetune_models` as its first argument.
result_directory = Path("dataset_benchmark") / "leave_one_out" / RESULT_SUBDIRECTORY_NAME

# Arguments stay positional: the parameter names of `finetune_models` are not
# visible from this notebook.
finetune_models(
    result_directory,
    train_datasets,
    eval_datasets,
    EVAL_STRATEGY,
    TRAINING_SEED,
    DATA_SEED,
)