# Analysis of experiment results: MACH-IV scores for the vanilla and persona models

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

def load_results(file_path):
    """Load a JSON results file and return its parsed contents.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the platform locale and can mis-decode non-ASCII text.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def calculate_mach_iv_score(log_probs):
    """Compute a MACH-IV total from per-question answer log-probabilities.

    For every question, the answer with the highest log-probability is taken
    as the response and its numeric value is added to a starting constant
    of 20.

    Args:
        log_probs: Mapping of question id to a mapping of answer label to
            log-probability. Answer labels must be integer strings (the
            plotting code in this file indexes them as "1".."5").

    Returns:
        int: 20 plus the sum of the selected answer values.

    Raises:
        ValueError: If a question has no answers, or an answer label cannot
            be parsed as an integer.
    """
    # NOTE(review): 20 is presumably the constant added in standard MACH-IV
    # scoring, and reverse-keyed items are not handled here -- confirm.
    mach_iv_score = 20
    for question_id, answers in log_probs.items():
        if not answers:
            raise ValueError(
                f"no log-probabilities for question {question_id!r}"
            )
        # exp() is monotonic, so the most likely answer can be picked straight
        # from the log-probabilities; exponentiating first underflows to 0.0
        # for very negative values and would make every answer tie.
        best_label = max(answers, key=answers.get)
        # Score by the label itself rather than by its position in the dict,
        # so the result does not depend on key order in the JSON file.
        mach_iv_score += int(best_label)
    return mach_iv_score

def plot_log_prob_distribution(log_probs, title):
    """Show a heatmap of log-probabilities, one row per question.

    Columns correspond to the scores 1 through 5; each cell is annotated with
    its value rounded to two decimals.

    Args:
        log_probs: Mapping of question id to a mapping whose keys include the
            strings "1".."5" and whose values are log-probabilities.
        title: Title drawn above the heatmap.

    Raises:
        KeyError: If a question lacks an entry for one of the scores "1".."5".
    """
    question_ids = list(log_probs)
    scores = range(1, 6)
    n_rows, n_cols = len(question_ids), len(scores)

    # One row per question, one column per score, filled a row at a time.
    log_probs_matrix = np.zeros((n_rows, n_cols))
    for row, question_id in enumerate(question_ids):
        per_score = log_probs[question_id]
        log_probs_matrix[row] = [per_score[str(score)] for score in scores]

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(log_probs_matrix, cmap="viridis")

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(scores)
    ax.set_yticklabels(question_ids)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # imshow puts columns on the x axis, so the text position is (col, row).
    for (row, col), value in np.ndenumerate(log_probs_matrix):
        ax.text(col, row, f"{value:.2f}", ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Read the stored results for the vanilla model first, then the persona model.
vanilla_results, persona_results = map(
    load_results,
    (
        "results/vanilla_model/mach_iv_scores.json",
        "results/persona_model/mach_iv_scores.json",
    ),
)

# Reduce each result set to a single MACH-IV total.
vanilla_mach_iv_score, persona_mach_iv_score = map(
    calculate_mach_iv_score,
    (vanilla_results, persona_results),
)

# Report both totals, one per line.
print(
    f"Vanilla Model MACH-IV Score: {vanilla_mach_iv_score}",
    f"Persona Model MACH-IV Score: {persona_mach_iv_score}",
    sep="\n",
)

# Draw one log-probability heatmap per model, vanilla first.
for model_results, plot_title in (
    (vanilla_results, "Vanilla Model Log Probability Distribution"),
    (persona_results, "Persona Model Log Probability Distribution"),
):
    plot_log_prob_distribution(model_results, plot_title)