In [1]:
from collections import Counter
import pickle

from sklearn.metrics import f1_score, accuracy_score

In [None]:
# Load data_list from a file
def load_data_list(file_name):
    """
    Read a pickled object from disk and return it.

    Unpickling can execute arbitrary code, so only point this at files
    from a trusted source.

    Args:
        file_name (str): Path of the pickle file to read.

    Returns:
        The object stored in the pickle file.
    """
    with open(file_name, mode='rb') as pickle_handle:
        loaded = pickle.load(pickle_handle)
    # Report only after the file has been read and closed successfully.
    print(f"data_list loaded from {file_name}")
    return loaded

# Load the pickled evaluation data that later cells read as `data_list`.
# NOTE(review): the saved output of this cell reports "./llm_evaluation.pkl",
# which differs from the path below — re-run to confirm the intended location.
file_name = "../llm_evaluation.pkl"
data_list = load_data_list(file_name)

data_list loaded from ./llm_evaluation.pkl


In [3]:
from sklearn.metrics import precision_score, recall_score
from collections import defaultdict

def custom_accuracy_score(y_true, y_pred):
    """
    Order-insensitive accuracy: a prediction counts as correct when it contains
    the same characters as the true label, regardless of their order
    (e.g. "BA" matches "AB").

    Args:
        y_true (list): List of true labels.
        y_pred (list): List of predicted labels.

    Returns:
        float: Fraction of entries in y_true whose prediction matches.
            Returns 0.0 when y_true is empty.
    """
    total = len(y_true)
    if total == 0:
        # No samples to score: report 0.0 rather than raising ZeroDivisionError,
        # matching the zero-fallbacks used by the other metric helpers.
        return 0.0
    correct = sum(sorted(true) == sorted(pred) for true, pred in zip(y_true, y_pred))
    return correct / total

def custom_f1_score(y_true, y_pred):
    """
    Order-insensitive F1: labels are canonicalized by sorting their characters
    before scoring, so "BA" and "AB" are treated as the same class.

    The result is the harmonic mean of the weighted-average precision and the
    weighted-average recall computed over the canonicalized labels.

    Args:
        y_true (list): List of true labels.
        y_pred (list): List of predicted labels.

    Returns:
        float: Custom F1 score (0 when precision + recall is 0).
    """
    def _canonical(labels):
        # Sort each label's characters so permutations collapse to one class.
        return [''.join(sorted(label)) for label in labels]

    truth = _canonical(y_true)
    guesses = _canonical(y_pred)

    scoring_options = {"average": 'weighted', "zero_division": 0}
    precision = precision_score(truth, guesses, **scoring_options)
    recall = recall_score(truth, guesses, **scoring_options)

    denominator = precision + recall
    if denominator > 0:
        return 2 * (precision * recall) / denominator
    return 0

def compute_accuracy(data_list, vote_key):
    """
    Score one voting method over every entry with the custom accuracy metric.

    Args:
        data_list (dict): A dictionary containing questions, answers, and voting results.
        vote_key (str): The key in each entry holding the voting result to evaluate.

    Returns:
        float: The custom accuracy of the voting method.
    """
    entries = list(data_list.values())
    gold = [item["answer"] for item in entries]
    voted = [item[vote_key] for item in entries]
    return custom_accuracy_score(gold, voted)

def compute_grouped_accuracy(data_list, vote_key, group_key):
    """
    Score one voting method per group with the custom accuracy metric.

    Args:
        data_list (dict): A dictionary containing questions, answers, and voting results.
        vote_key (str): The key in each entry holding the voting result to evaluate.
        group_key (str): The key in each entry to group by (e.g., 'domain' or 'q_type').

    Returns:
        dict: Maps each group value to the custom accuracy within that group.
    """
    # Bucket the entries by their group value, keeping first-seen group order.
    buckets = defaultdict(list)
    for record in data_list.values():
        buckets[record[group_key]].append(record)

    # Score each bucket independently.
    return {
        label: custom_accuracy_score(
            [member["answer"] for member in members],
            [member[vote_key] for member in members],
        )
        for label, members in buckets.items()
    }

aggregate_models = ["chatglm3-6b-chat", "qwen2.5-7b-instruct", "baichuan2-7b-chat", "deepseek-v2-lite-chat", "hunyuan"]

# Build a copy of data_list in which each entry's "results" keeps only the
# answers given by the models listed in aggregate_models.
filtered_data_list = {}
for key, entry in data_list.items():
    filtered_results = [
        {model: answer}
        for result in entry["results"]
        for model, answer in result.items()
        if model in aggregate_models
    ]
    # Shallow copy so the original entry keeps its full "results" list.
    filtered_entry = entry.copy()
    filtered_entry["results"] = filtered_results
    filtered_data_list[key] = filtered_entry
    
def compute_f1_by_group(data_list, vote_key, group_key):
    """
    Score one voting method per group with the custom F1 metric.

    Args:
        data_list (dict): A dictionary containing questions, answers, and voting results.
        vote_key (str): The key in each entry holding the voting result to evaluate.
        group_key (str): The key in each entry to group by (e.g., 'domain' or 'q_type').

    Returns:
        dict: Maps each group value to the custom F1 score within that group.
    """
    # Collect the entries belonging to each group value.
    by_group = defaultdict(list)
    for item in data_list.values():
        by_group[item[group_key]].append(item)

    # Evaluate every group on its own labels and votes.
    scores = {}
    for group_value in by_group:
        members = by_group[group_value]
        gold = [member["answer"] for member in members]
        voted = [member[vote_key] for member in members]
        scores[group_value] = custom_f1_score(gold, voted)

    return scores

def compute_accuracy_based_weights(data_list, models):
    """
    Compute accuracy-based weights for models from their agreement with the
    ground truth, using the order-insensitive comparison (same characters,
    any order).

    Answers from models that are not listed in `models` are ignored, so the
    function can be called on data that still contains other models' results.

    Args:
        data_list (dict): A dictionary containing questions, answers, and model results.
            Each entry provides "answer" and "results" (a list of {model: answer} dicts).
        models (list): A list of model names to compute weights for.

    Returns:
        dict: Maps each model name to the fraction of its answers that match
            the ground truth, or 0 for a model that gave no answers.
    """
    correct_labels = {model: 0 for model in models}
    total_labels = {model: 0 for model in models}

    for q in data_list.values():
        # The ground truth is the same for every model answering this question.
        ground_truth_key = sorted(q["answer"])
        for model_result in q["results"]:
            for model, answer in model_result.items():
                if model not in total_labels:
                    # Not one of the requested models: skip instead of raising KeyError.
                    continue
                total_labels[model] += 1
                if sorted(answer) == ground_truth_key:
                    correct_labels[model] += 1

    # Normalize: correct answers divided by total answers, 0 when a model never answered.
    return {
        model: correct_labels[model] / total_labels[model] if total_labels[model] > 0 else 0
        for model in models
    }

In [4]:
def majority_vote(models_results):
    """
    Perform majority voting to determine the most common answer.

    Every answer in every result dictionary counts as one vote, so empty
    dictionaries contribute nothing and dictionaries holding several models
    contribute one vote per model. Ties go to the answer that was seen first.

    Args:
        models_results (list): A list of {model: answer} dictionaries.

    Returns:
        str: The answer with the highest vote count, or "" when there are no votes.
    """
    # Flatten all answers; this also tolerates empty or multi-model dictionaries.
    answers = [answer for result in models_results for answer in result.values()]
    if not answers:
        # Nothing to vote on (e.g. every model was filtered out).
        return ""
    # Count the occurrences of each answer and take the most frequent one.
    vote_counts = Counter(answers)
    return vote_counts.most_common(1)[0][0]

def weighted_majority_vote(models_results, accuracy_weights):
    """
    Perform a weighted majority vote based on accuracy weights.

    Each model's answer receives that model's weight; models missing from
    `accuracy_weights` contribute a weight of 0. Ties go to the answer that
    was seen first.

    Args:
        models_results (list): A list of {model: answer} dictionaries.
        accuracy_weights (dict): A dictionary with model names as keys and their accuracy weights as values.

    Returns:
        str: The answer with the highest weighted vote, or "" when there are no votes.
    """
    weighted_votes = Counter()

    for model_result in models_results:
        for model, answer in model_result.items():
            # Add the weighted vote for the answer
            weighted_votes[answer] += accuracy_weights.get(model, 0)

    if not weighted_votes:
        # Nothing to vote on (e.g. every model was filtered out).
        return ""

    return weighted_votes.most_common(1)[0][0]


In [5]:
# Derive per-model weights from how often each model matches the ground truth.
# NOTE(review): the weights come from the same entries (and answers) that are
# scored below — confirm this is intended, as it may favor the weighted vote.
accuracy_weights = compute_accuracy_based_weights(filtered_data_list, aggregate_models)


# Store both votes on each entry so the metric helpers can read them by key.
for q_id, entry in filtered_data_list.items():
    # Unweighted vote: every model's answer counts equally.
    majority_result = majority_vote(entry["results"])
    entry["majority_vote"] = majority_result

    # Weighted vote: each model's answer counts with its accuracy weight.
    weighted_result = weighted_majority_vote(entry["results"], accuracy_weights)
    entry["weighted_majority_vote"] = weighted_result

# Overall accuracy of both methods across all entries
majority_vote_accuracy = compute_accuracy(filtered_data_list, "majority_vote")
weighted_majority_vote_accuracy = compute_accuracy(filtered_data_list, "weighted_majority_vote")

print("Majority Vote Accuracy (total):", majority_vote_accuracy)
print("Weighted Majority Vote Accuracy (total):", weighted_majority_vote_accuracy)

print("\n======== Accuracy Scores ========\n")

# Accuracy of both methods broken down by the "domain" field
majority_vote_accuracy_by_domain = compute_grouped_accuracy(filtered_data_list, "majority_vote", "domain")
weighted_majority_vote_accuracy_by_domain = compute_grouped_accuracy(filtered_data_list, "weighted_majority_vote", "domain")

# Accuracy of both methods broken down by the "q_type" field
majority_vote_accuracy_by_q_type = compute_grouped_accuracy(filtered_data_list, "majority_vote", "q_type")
weighted_majority_vote_accuracy_by_q_type = compute_grouped_accuracy(filtered_data_list, "weighted_majority_vote", "q_type")
print("By Domain:\n")
print("Majority Vote:", majority_vote_accuracy_by_domain)
print("Weighted Majority Vote:", weighted_majority_vote_accuracy_by_domain)

print("\nBy Question Type:\n")
print("Majority Vote:", majority_vote_accuracy_by_q_type)
print("Weighted Majority Vote:", weighted_majority_vote_accuracy_by_q_type)

print("\n======== F1 Scores ========\n")

# Custom F1 of both methods broken down by the "domain" field
majority_vote_f1_by_domain = compute_f1_by_group(filtered_data_list, "majority_vote", "domain")
weighted_majority_vote_f1_by_domain = compute_f1_by_group(filtered_data_list, "weighted_majority_vote", "domain")
# Custom F1 of both methods broken down by the "q_type" field
majority_vote_f1_by_q_type = compute_f1_by_group(filtered_data_list, "majority_vote", "q_type")
weighted_majority_vote_f1_by_q_type = compute_f1_by_group(filtered_data_list, "weighted_majority_vote", "q_type")

print("By Domain:\n")
print("Majority Vote:", majority_vote_f1_by_domain)
print("Weighted Majority Vote:", weighted_majority_vote_f1_by_domain)

print("\nBy Question Type:\n")
print("Majority Vote:", majority_vote_f1_by_q_type)
print("Weighted Majority Vote:", weighted_majority_vote_f1_by_q_type)


Majority Vote Accuracy (total): 0.5955601803676726
Weighted Majority Vote Accuracy (total): 0.6398427563880217


By Domain:

Majority Vote: {'Production Safety': 0.5697709923664123, 'Oil and Gas': 0.6097256857855362, 'Fire Safety': 0.6624068157614483, 'Civil Engineering': 0.5338613861386139, 'Economics and Finance': 0.5951075495571488, 'Banking and Insurance': 0.6757512229210343}
Weighted Majority Vote: {'Production Safety': 0.6178625954198473, 'Oil and Gas': 0.6327930174563591, 'Fire Safety': 0.6950656727014555, 'Civil Engineering': 0.5754455445544554, 'Economics and Finance': 0.649514972585407, 'Banking and Insurance': 0.7372466806429071}

By Question Type:

Majority Vote: {'Single Choice': 0.6524047138762077, 'Multiple Choice': 0.40807307012374777, 'True/False': 0.6180602006688963}
Weighted Majority Vote: {'Single Choice': 0.6961460876950845, 'Multiple Choice': 0.5123747790218032, 'True/False': 0.6180602006688963}


By Domain:

Majority Vote: {'Production Safety': 0.5960662281172608