In [None]:
from transformers import (
    AutoAdapterModel,
    TrainingArguments, 
    AdapterTrainer,
    AutoTokenizer,
)
from datasets import (
    Dataset,
    DatasetDict
)

import numpy as np
import pandas as pd
import evaluate
import torch
import gc
import os

In [None]:
# Root folders, given relative to this notebook.
DATA_PATH = "../../data/processed"  # test CSVs are read from <DATA_PATH>/<dimension>/<task>_test.csv
MODEL_PATH = "../../models"

# Perturbation dimensions; each one names a sub-folder of DATA_PATH and, for
# adapters fine-tuned on perturbed data, a sub-folder of the model directory.
dimensions = ["gender", "race", "religion", "nationality", "country", "merged"]

# Task identifiers; used for the test file name and as the adapter name.
tasks = [
    "buzzfeed", "politifact", "twittercovidq2", "clef22",
    "propaganda", "webis", "pheme", "basil",
    "shadesoftruth", "fingerprints", "clickbait",
]

# Which text column of the test set is evaluated:
# "unperturbed" keeps `text`, "perturbed" substitutes `perturbed_text`.
testset_state = ["unperturbed", "perturbed"]

# Base checkpoints, passed to `from_pretrained` for tokenizer and model.
models_name = [
    "bert-base-cased", "roberta-base", "distilbert-base-cased",
    "microsoft/deberta-base", "facebook/FairBERTa",
]

# Which adapter variant is loaded: "unperturbed" resolves to the `vanilla`
# model folder, "perturbed" to the folder named after the current dimension.
model_finetunes = ["unperturbed", "perturbed"]

In [None]:
# Evaluate every (dimension, task, test-set state, base model, adapter variant)
# combination and collect the resulting metrics in `results_df`.

MAX_LENGTH = 128  # every example is padded / truncated to this many tokens

# Loading the metrics does not depend on any loop variable, so do it once.
f1_metric = evaluate.load("f1")
accuracy_metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return macro-F1 and accuracy for a `(logits, labels)` evaluation pair."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    results = {}
    results.update(f1_metric.compute(predictions=preds, references=labels, average="macro"))
    results.update(accuracy_metric.compute(predictions=preds, references=labels))
    return results


def resolve_adapter_path(model_name, model_finetune, dimension, task):
    """Return the directory of the saved adapter for one configuration.

    Adapters fine-tuned on unperturbed data live under `vanilla`; the others
    live under the folder named after the perturbation dimension. The adapter
    itself is stored as `<model folder>/<run dir>/<task>`.

    Raises:
        FileNotFoundError: if the model folder contains no run directory.
    """
    variant = "vanilla" if model_finetune == "unperturbed" else dimension
    model_folder_path = os.path.join(MODEL_PATH, model_name, variant, task)

    # `os.listdir` returns entries in arbitrary order; sort and keep only
    # directories so the selected run is deterministic.
    run_dirs = sorted(
        entry
        for entry in os.listdir(model_folder_path)
        if os.path.isdir(os.path.join(model_folder_path, entry))
    )
    if not run_dirs:
        raise FileNotFoundError(f"No adapter run directory found in {model_folder_path!r}")
    return os.path.join(model_folder_path, run_dirs[0], task)


def load_test_dataset(dimension, task, state):
    """Load `<task>_test.csv` for `dimension` as a label-encoded `DatasetDict`.

    When `state` is "perturbed", the `perturbed_text` column replaces `text`.
    """
    test_df = pd.read_csv(os.path.join(DATA_PATH, dimension, f"{task}_test.csv"))
    if state == "perturbed":
        test_df = test_df.drop(columns=["text"]).rename(columns={"perturbed_text": "text"})

    dataset = DatasetDict({"test": Dataset.from_pandas(test_df)})
    # NOTE(review): the label ids are derived from the test split alone; this
    # assumes they match the ids used when the adapter was trained - confirm.
    return dataset.class_encode_column("labels")


tokenizers = {}          # base model name -> tokenizer, loaded on first use
evaluation_records = []  # one dict per evaluated configuration

for dimension in dimensions:
    for task in tasks:
        for state in testset_state:
            # The raw test set depends only on (dimension, task, state).
            dataset = load_test_dataset(dimension, task, state)

            for model_name in models_name:
                tokenizer = tokenizers.get(model_name)
                if tokenizer is None:
                    tokenizer = tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)

                def tokenize_function(examples):
                    return tokenizer(
                        examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH
                    )

                # Tokenisation depends on the base model but not on the adapter variant.
                test_dataset = dataset.map(tokenize_function, batched=True)["test"]

                for model_finetune in model_finetunes:
                    CONFIG = {
                        "task_name": task,
                        "model_name": model_name,
                        "model_path": resolve_adapter_path(model_name, model_finetune, dimension, task),
                        "max_length": MAX_LENGTH,
                    }

                    model = AutoAdapterModel.from_pretrained(CONFIG["model_name"])
                    model.load_adapter(CONFIG["model_path"])
                    model.set_active_adapters(task)

                    trainer = AdapterTrainer(
                        model=model,
                        args=TrainingArguments(output_dir="evaluation"),
                        compute_metrics=compute_metrics,
                    )

                    metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")

                    # Keep the metrics together with the configuration that produced them.
                    evaluation_records.append(
                        {
                            "dimension": dimension,
                            "task": task,
                            "testset_state": state,
                            "model_name": model_name,
                            "model_finetune": model_finetune,
                            **metrics,
                        }
                    )

                    # The trainer holds a reference to the model, so both must be
                    # dropped before the memory can actually be reclaimed.
                    del trainer, model
                    gc.collect()
                    torch.cuda.empty_cache()

results_df = pd.DataFrame(evaluation_records)
results_df