In [None]:
from inspect_ai.log import read_eval_log
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns

def load_scores(path):
    """Load per-sample faithfulness scores from one inspect_ai ``.eval`` log.

    Args:
        path: Path to an evaluation log readable by ``read_eval_log``.

    Returns:
        A DataFrame with one row per scored sample and the columns
        ``nhops`` (from sample metadata ``num_hops``), ``answer_same`` (the
        first scorer's ``value``), ``original_answer``,
        ``continuation_answer``, ``target`` (from that score's metadata, or
        ``None`` when the score has no metadata or lacks the key) and
        ``truncate_frac`` (the log's ``truncate_frac`` task argument,
        identical for every row). Samples with no scores are skipped.

    Raises:
        KeyError: if a sample's metadata has no ``num_hops`` entry, or the
            log's task arguments have no ``truncate_frac`` entry.
    """
    log = read_eval_log(path)
    rows = []
    # `samples` is optional on an eval log; treat a missing list as empty
    # instead of failing on `for ... in None`.
    for sample in log.samples or []:
        # `scores` is optional too (presumably unscored / errored samples).
        # Such samples carry no faithfulness value, so skip them rather than
        # crash on indexing an empty or None mapping.
        if not sample.scores:
            continue
        nhops = sample.metadata['num_hops']
        # Only the first scorer's result is used (dict insertion order).
        score = next(iter(sample.scores.values()))
        meta = score.metadata
        if meta is None:
            rows.append((nhops, score.value, None, None, None))
        else:
            # .get keeps a partially-populated metadata dict from aborting the
            # whole load, matching the `metadata is None` fallback above.
            rows.append((
                nhops,
                score.value,
                meta.get('original_answer'),
                meta.get('continuation_answer'),
                meta.get('target'),
            ))

    df = pd.DataFrame(
        rows,
        columns=['nhops', 'answer_same', 'original_answer', 'continuation_answer', 'target'],
    )

    # Scalar broadcast: every row from this log shares the same truncation fraction.
    df['truncate_frac'] = log.eval.task_args['truncate_frac']

    return df

# Load and combine all evaluation logs
logs_dir = 'paraphrase_faithfulness_logs'  # Update this to your log directory

# os.listdir order is arbitrary; sorting makes the combined frame's row order reproducible.
log_files = sorted(name for name in os.listdir(logs_dir) if name.endswith('.eval'))
if not log_files:
    # Otherwise pd.concat fails with the opaque "No objects to concatenate".
    raise FileNotFoundError(f"No .eval logs found in {logs_dir!r}")

# ignore_index=True: every per-log frame starts its index at 0, and duplicate
# index labels can trigger reindexing errors inside seaborn plotting calls.
all_scores = pd.concat(
    [load_scores(os.path.join(logs_dir, name)) for name in log_files],
    ignore_index=True,
)

# Plot 1: Faithfulness vs number of hops for different truncation fractions
plt.figure(figsize=(10, 6))
sns.set_theme(style="whitegrid")
# One line per truncation fraction; errorbar="se" draws the standard error of the mean.
ax = sns.lineplot(
    data=all_scores,
    x="nhops",
    y="answer_same",
    hue="truncate_frac",
    errorbar="se",
)
ax.set(
    title="Faithfulness vs Chain Length",
    xlabel="Number of Hops",
    ylabel="Faithfulness Score",
)
plt.show()

# Plot 2: Faithfulness vs truncation fraction for different hop counts
plt.figure(figsize=(10, 6))
# One line per hop count; errorbar="se" draws the standard error of the mean.
ax = sns.lineplot(
    data=all_scores,
    x="truncate_frac",
    y="answer_same",
    hue="nhops",
    errorbar="se",
)
ax.set(
    title="Faithfulness vs Truncation Fraction",
    xlabel="Truncation Fraction",
    ylabel="Faithfulness Score",
)
plt.show()

# Optional: Add summary statistics
print("\nSummary Statistics:")
# Mean, standard deviation and count of `answer_same` for each (nhops, truncate_frac) cell.
summary = (
    all_scores
    .groupby(['nhops', 'truncate_frac'])['answer_same']
    .agg(['mean', 'std', 'count'])
)
print(summary)