In [None]:
from itertools import chain
from pathlib import Path

from datasets import concatenate_datasets

from utils.dataset_utils import get_politicalness_datasets, \
    get_politicalness_datasets_from_leaning_datasets, politicalness_label_mapping
from utils.existing_models.politicalness import get_existing_politicalness_models
from utils.model_utils import evaluate_models, get_custom_politicalness_models
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Size argument handed to `take_even_class_distribution_sample` for every dataset.
# NOTE(review): whether this is a per-class or a total count is decided in
# utils.dataset_utils — confirm there.
DATASET_SAMPLE_SIZE = 5

# Factory for the dataset stream: the directly labelled politicalness datasets
# followed by the ones derived from leaning datasets.
GET_DATASETS = lambda: chain(
    get_politicalness_datasets(),
    get_politicalness_datasets_from_leaning_datasets(),
)

# Sample each dataset, prepare it for inference and convert it to a Hugging Face
# dataset; the result is materialised as a list so it can be iterated repeatedly.
datasets = [
    source_dataset
    .take_even_class_distribution_sample(DATASET_SAMPLE_SIZE)
    .transform_for_inference()
    .to_huggingface()
    for source_dataset in GET_DATASETS()
]

In [None]:
# A factory is passed instead of the models themselves: the models come from a
# generator, which can only be consumed once, so each call produces a fresh one.
GET_MODELS = lambda: get_existing_politicalness_models()

# Evaluate the models produced by the factory on each sampled dataset.
results = evaluate_models(GET_MODELS, datasets)

In [None]:
# Show the `count` field of the per-dataset evaluation results.
# NOTE(review): presumably the number of evaluated samples per model/dataset pair —
# confirm against `evaluate_models` in utils.model_utils.
results.count

In [None]:
# Show the accuracy values stored on the per-dataset evaluation results.
results.accuracy

In [None]:
# Show the F1 values stored on the per-dataset evaluation results.
results.f1

In [None]:
# Show the precision values stored on the per-dataset evaluation results.
results.precision

In [None]:
# Show the recall values stored on the per-dataset evaluation results.
results.recall

In [None]:
# Positional indices into `results.confusion_matrix`: the first selects the row
# (model), the second the column (dataset). Change them to inspect another pair.
CONFUSION_MATRIX_MODEL_INDEX = 0
CONFUSION_MATRIX_DATASET_INDEX = 0

# Deliberately not named `display`: that would shadow IPython's built-in
# `display()` helper for every cell executed afterwards in this kernel.
# NOTE(review): `display_labels` assumes the key order of
# `politicalness_label_mapping` matches the class order of the matrix — confirm.
confusion_matrix_display = ConfusionMatrixDisplay(
    confusion_matrix=results.confusion_matrix.iloc[
        CONFUSION_MATRIX_MODEL_INDEX, CONFUSION_MATRIX_DATASET_INDEX
    ],
    display_labels=list(politicalness_label_mapping.keys()),
)
confusion_matrix_display.plot()

# Title the figure with the selected row/column labels so it can be told apart
# from the other confusion matrices in this notebook.
confusion_matrix_display.ax_.set_title(
    f"{results.confusion_matrix.index[CONFUSION_MATRIX_MODEL_INDEX]} / "
    f"{results.confusion_matrix.columns[CONFUSION_MATRIX_DATASET_INDEX]}"
);

In [None]:
# Pool every sampled dataset into a single Hugging Face dataset.
concatenated_dataset = concatenate_datasets(datasets)

# Re-run the evaluation on the pooled data; `evaluate_models` is given a list of
# datasets, so the pooled dataset is wrapped in a one-element list.
concatenated_results = evaluate_models(GET_MODELS, [concatenated_dataset])

In [None]:
# Show the `count` field of the results for the concatenated dataset.
# NOTE(review): presumably the number of evaluated samples per model — confirm
# against `evaluate_models` in utils.model_utils.
concatenated_results.count

In [None]:
# Show the accuracy values stored on the results for the concatenated dataset.
concatenated_results.accuracy

In [None]:
# Show the F1 values stored on the results for the concatenated dataset.
concatenated_results.f1

In [None]:
# Show the precision values stored on the results for the concatenated dataset.
concatenated_results.precision

In [None]:
# Show the recall values stored on the results for the concatenated dataset.
concatenated_results.recall

In [None]:
# Positional indices into `concatenated_results.confusion_matrix`: the first
# selects the row (model), the second the column (dataset). Only one dataset was
# evaluated here, so the dataset index stays 0.
CONFUSION_MATRIX_MODEL_INDEX = 0
CONFUSION_MATRIX_DATASET_INDEX = 0

# Deliberately not named `display`: that would shadow IPython's built-in
# `display()` helper for every cell executed afterwards in this kernel.
# NOTE(review): `display_labels` assumes the key order of
# `politicalness_label_mapping` matches the class order of the matrix — confirm.
concatenated_confusion_matrix_display = ConfusionMatrixDisplay(
    confusion_matrix=concatenated_results.confusion_matrix.iloc[
        CONFUSION_MATRIX_MODEL_INDEX, CONFUSION_MATRIX_DATASET_INDEX
    ],
    display_labels=list(politicalness_label_mapping.keys()),
)
concatenated_confusion_matrix_display.plot()

# Title the figure with the selected row/column labels so it can be told apart
# from the other confusion matrices in this notebook.
concatenated_confusion_matrix_display.ax_.set_title(
    f"{concatenated_results.confusion_matrix.index[CONFUSION_MATRIX_MODEL_INDEX]} / "
    f"{concatenated_results.confusion_matrix.columns[CONFUSION_MATRIX_DATASET_INDEX]}"
);