In [None]:
import sys
# Extend the module search path with the parent directory. This is placed ahead
# of the `experiments.*` imports below, which presumably rely on it.
# NOTE(review): '..' is resolved against the kernel's working directory, so this
# assumes the notebook is run from its own folder — confirm.
sys.path.append('..')

import matplotlib.pyplot as plt
import seaborn as sns
# Project-local ablation entry points; each is called in a later cell with the
# path to a dataset JSON file.
from experiments.ablation_chain_length import run_ablation_chain_length
from experiments.ablation_threshold import run_ablation_threshold
import pandas as pd

# Use seaborn's "whitegrid" style for the figures drawn in this notebook.
sns.set_style("whitegrid")

In [None]:
# Chain-length ablation over the HotpotQA dev (distractor) file.
# The returned object is later fed to pd.DataFrame.from_dict(orient='index').
chain_results = run_ablation_chain_length(
    "../data/hotpotqa/hotpot_dev_distractor_v1.json"
)

In [None]:
# Tabulate the chain-length ablation: every key of `chain_results` becomes a row,
# and the former index is turned into a regular 'Chain Length' column for plotting.
df_chain = (
    pd.DataFrame.from_dict(chain_results, orient='index')
    .reset_index()
    .rename(columns={'index': 'Chain Length'})
)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# (column to plot, panel title) for each of the two side-by-side bar charts.
chain_panels = [
    ('F1', 'F1 Score vs Chain Length'),
    ('EM', 'Exact Match vs Chain Length'),
]
for panel, (metric_col, panel_title) in zip(ax, chain_panels):
    sns.barplot(data=df_chain, x='Chain Length', y=metric_col, ax=panel)
    panel.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# Entailment-threshold ablation over the HotpotQA dev (distractor) file.
# The returned object is later fed to pd.DataFrame.from_dict(orient='index').
threshold_results = run_ablation_threshold(
    "../data/hotpotqa/hotpot_dev_distractor_v1.json"
)

In [None]:
# Tabulate the threshold ablation: every key of `threshold_results` becomes a row,
# and the former index is turned into a regular 'Threshold' column for plotting.
df_thresh = (
    pd.DataFrame.from_dict(threshold_results, orient='index')
    .reset_index()
    .rename(columns={'index': 'Threshold'})
)

fig, ax = plt.subplots(1, 3, figsize=(18, 5))

# (column to plot, panel title) for each of the three side-by-side line charts.
thresh_panels = [
    ('F1', 'F1 Score vs Entailment Threshold'),
    ('EM', 'Exact Match vs Entailment Threshold'),
    ('Chains Kept', 'Fraction of Chains Kept vs Threshold'),
]
for panel, (metric_col, panel_title) in zip(ax, thresh_panels):
    sns.lineplot(data=df_thresh, x='Threshold', y=metric_col, ax=panel, marker='o')
    panel.set_title(panel_title)

plt.tight_layout()
plt.show()