Experiment resources related to the QUITE corpus (EMNLP 2024).

Copyright (c) 2024 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

In [None]:
from datasets import Dataset, Split
import re
from os import listdir
from os.path import join

from src.constants import QUITE_Config, FLOAT_PERCENTAGE_PATTERN, PERCENTAGE_PATTERN, INVALID_INDICATION_FLAG, PROJECT_ROOT
from src.experiments.src.numeric_evaluator import NumericEvaluator
from src.utils.quite_dataset_loaders import quite_dataset_mappings

# Evaluator instance shared by all metric computations below.
ne: NumericEvaluator = NumericEvaluator()

# Pick the test-split mappings first, then the evidence-query-pairs
# configuration from them.
test_split_mappings = quite_dataset_mappings[Split.TEST]
dataset: Dataset = test_split_mappings[QUITE_Config.EVIDENCE_QUERY_PAIRS.value]

In [None]:
def _parse_result_filename(file: str) -> tuple[str, int]:
    """Split a result file name into ``(network_name, qe_id)``.

    The name is expected to look like ``<network>_<qe_id>`` followed by four
    trailing characters that are dropped (presumably a ``.txt``-style
    extension -- confirm against the result folders). ``hepar2_1`` is handled
    separately because the network name itself contains an underscore.
    """
    if "hepar2_1" in file:
        network_name = "hepar2_1"
        return network_name, int(file[len(network_name) + 1:-4])
    parts = file.split("_")
    return parts[0], int(parts[1][:-4])


def _extract_probability(final_output: str) -> float:
    """Extract the predicted probability from the tail of a model output.

    The last match of ``PERCENTAGE_PATTERN`` is preferred and divided by 100;
    otherwise the last match of ``FLOAT_PERCENTAGE_PATTERN`` is used as-is.
    If neither pattern matches, ``INVALID_INDICATION_FLAG`` is returned.
    """
    percentages: list[str] = re.findall(PERCENTAGE_PATTERN, final_output)
    if percentages:
        return float(percentages[-1].rstrip("%")) / 100
    floats: list[str] = re.findall(FLOAT_PERCENTAGE_PATTERN, final_output)
    if floats:
        return float(floats[-1])
    return INVALID_INDICATION_FLAG


def calculate_metrics(
    input_paths: dict[str, dict[str, str]],
    dataset: Dataset,
) -> tuple[dict[str, dict[str, dict]], dict[str, dict[str, dict]]]:
    """Score stored model outputs against the gold answers in ``dataset``.

    Args:
        input_paths: Maps prompt name -> model name -> directory containing
            one output file per query. Models mapped to ``""`` are skipped.
        dataset: Rows providing at least ``qe_id``, ``answer``, ``input`` and
            ``reasoning_types``.

    Returns:
        A tuple ``(result_dict, preds_per_network)``:

        * ``result_dict[prompt][model]`` is the result of
          ``ne.get_metrics`` over all files of that prompt/model.
        * ``preds_per_network[prompt][model][network]`` holds the lists
          ``"predicted"``, ``"gt"`` and ``"rt"`` for that network. For
          ``gpt-4-turbo`` with the ``causalcot`` prompt it additionally holds
          ``"accuracy"``, ``"wrong"`` and ``"error"``, and the per-network
          dict is printed.

    Raises:
        IndexError: If a file refers to a ``qe_id`` that is not in ``dataset``.
        AssertionError: If a file does not contain the dataset input text.
    """
    # Index the rows once instead of scanning the dataset for every file.
    # If a qe_id occurs several times, the last row wins.
    rows_by_qe_id: dict = {row["qe_id"]: row for row in dataset}

    result_dict: dict[str, dict[str, dict]] = {}
    preds_per_network: dict[str, dict[str, dict]] = {}

    for prompt, model_paths in input_paths.items():
        result_dict[prompt] = {}
        preds_per_network[prompt] = {}
        for model, path in model_paths.items():
            if path == "":
                continue
            predicted_probs: list[float] = []
            ground_truth_probs: list[float] = []
            reasoning_types: list[list[str]] = []
            network_preds: dict[str, dict] = {}
            preds_per_network[prompt][model] = network_preds

            for file in listdir(path):
                network_name, qe_id = _parse_result_filename(file)
                with open(join(path, file)) as f:
                    content: str = f.read()

                row = rows_by_qe_id.get(qe_id)
                if row is None:
                    raise IndexError(
                        f"No dataset row with qe_id={qe_id} (file {join(path, file)})"
                    )

                answer: float = float(row["answer"])
                # Sanity check: the stored output must belong to this query.
                assert row["input"] in content, (
                    f"Dataset input for qe_id={qe_id} not found in {join(path, file)}"
                )

                # Only the last 350 characters are searched for the answer.
                predicted: float = _extract_probability(content[-350:])

                predicted_probs.append(predicted)
                ground_truth_probs.append(answer)
                reasoning_types.append(row["reasoning_types"])

                # Record the same prediction (including the invalid sentinel)
                # per network, so these lists stay aligned with the overall
                # lists and never reuse a value from a previous file.
                per_network = network_preds.setdefault(
                    network_name, {"predicted": [], "gt": [], "rt": []}
                )
                per_network["predicted"].append(predicted)
                per_network["gt"].append(answer)
                per_network["rt"].append(row["reasoning_types"])

            # Per-network breakdown is only computed for this one setting.
            if model == "gpt-4-turbo" and prompt == "causalcot":
                for per_network in network_preds.values():
                    network_results = ne.get_metrics(
                        predicted_probs=per_network["predicted"],
                        true_probs=per_network["gt"],
                        reasoning_types=per_network["rt"],
                    )
                    per_network["accuracy"] = network_results["accuracy"]
                    per_network["wrong"] = network_results["wrong"]
                    per_network["error"] = network_results["error"]
                print(network_preds)

            result_dict[prompt][model] = ne.get_metrics(
                predicted_probs=predicted_probs,
                true_probs=ground_truth_probs,
                reasoning_types=reasoning_types,
            )

    return result_dict, preds_per_network



# Numeric

In [None]:
# Result directories for the numeric setting: prompt name -> model name -> path.
numeric_input_paths: dict[str, dict[str, str]] = {
    "zero_shot": {
        "llama-3-8b": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/numeric/llama-3-8b"),
        "mixtral-8x7b": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/numeric/mixtral-8x7b"),
        "gpt-4-turbo": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/numeric/gpt4-turbo"),
    },
    "causalcot": {
        "llama-3-8b": join(PROJECT_ROOT, "paper_results/prompting/causalcot/numeric/llama-3-8b"),
        "mixtral-8x7b": join(PROJECT_ROOT, "paper_results/prompting/causalcot/numeric/mixtral-8x7b"),
        "gpt-4-turbo": join(PROJECT_ROOT, "paper_results/prompting/causalcot/numeric/gpt4-turbo"),
    },
    "qe_only": {
        "gpt-4-turbo": join(PROJECT_ROOT, "paper_results/prompting/qe-only/gpt4-turbo"),
    },
}

# Score every prompt/model combination of the numeric setting.
result_dict, preds_per_network = calculate_metrics(
    input_paths=numeric_input_paths,
    dataset=dataset,
)

# Print one metrics record per (prompt, model) pair.
for prompt_name, metrics_by_model in result_dict.items():
    for model_name, model_metrics in metrics_by_model.items():
        print(prompt_name, ", ", model_name, ": ", model_metrics, "\n\n\n")


            


# WEP

In [None]:
# Result directories for the WEP-based premises setting:
# prompt name -> model name -> path.
wep_input_paths: dict[str, dict[str, str]] = {
    "zero_shot": {
        "llama-3-8b": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/wep-based-premises/llama-3-8b"),
        "mixtral-8x7b": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/wep-based-premises/mixtral-8x7b"),
        "gpt-4-turbo": join(PROJECT_ROOT, "paper_results/prompting/zero-shot/wep-based-premises/gpt4-turbo"),
    },
    "causalcot": {
        "llama-3-8b": join(PROJECT_ROOT, "paper_results/prompting/causalcot/wep-based-premises/llama-3-8b"),
        "mixtral-8x7b": join(PROJECT_ROOT, "paper_results/prompting/causalcot/wep-based-premises/mixtral-8x7b"),
        "gpt-4-turbo": join(PROJECT_ROOT, "paper_results/prompting/causalcot/wep-based-premises/gpt4-turbo"),
    },
}

# Score every prompt/model combination of the WEP setting.
result_dict, preds_per_network = calculate_metrics(
    input_paths=wep_input_paths,
    dataset=dataset,
)

# Print one metrics record per (prompt, model) pair.
for prompt_name, metrics_by_model in result_dict.items():
    for model_name, model_metrics in metrics_by_model.items():
        print(prompt_name, ", ", model_name, ": ", model_metrics, "\n\n\n")