# Ablation Study Analysis

Compare timeout occurrences across datasets, classes, and ablation setups.

In [1]:
import json
from pathlib import Path
import pandas as pd

CHECKPOINTS_PATH = Path('results/res_ablation')
def scan_experiments():
    """Scan checkpoint directory and collect timeout statistics."""
    results = []

    for redis_file in CHECKPOINTS_PATH.rglob('redis_dump_readable.json'):
        # Parse path: CHECKPOINTS_PATH/<dataset>/workers_16/class_<label>/<setup>/redis_dump_readable.json
        parts = redis_file.relative_to(CHECKPOINTS_PATH).parts
        if len(parts) < 4:
            continue

        dataset = parts[0]
        class_label = parts[2].replace('class_', '')
        setup = parts[3]  # 'all_features' or 'ablation_R<>_NR<>_GP<>_BP<>'

        # Read the JSON file
        with open(redis_file, 'r') as f:
            data = json.load(f)

        # Get the "0" key (database dump)
        db_dump = data.get('0', {})

        # Collect stats
        total_keys = 0
        timeout_count = 0
        times_all = []
        times_no_timeout = []

        for key, value in db_dump.items():
            if isinstance(value, dict) and 'timeout_occurred' in value:
                total_keys += 1
                is_timeout = value['timeout_occurred'] == 'True'
                if is_timeout:
                    timeout_count += 1

                # Compute time (vte - vts)
                if 'vts' in value and 'vte' in value:
                    try:
                        vts = float(value['vts'])
                        vte = float(value['vte'])
                        elapsed = vte - vts
                        times_all.append(elapsed)
                        if not is_timeout:
                            times_no_timeout.append(elapsed)
                    except (ValueError, TypeError):
                        pass

        avg_time = sum(times_all) / len(times_all) if times_all else 0
        avg_time_no_timeout = sum(times_no_timeout) / len(times_no_timeout) if times_no_timeout else 0

        results.append({
            'dataset': dataset,
            'class': class_label,
            'setup': setup,
            'timeouts': timeout_count,
            'total_keys': total_keys,
            'avg_time': round(avg_time, 2),
            'avg_time_no_timeout': round(avg_time_no_timeout, 2),
            '_times_all': times_all,  # keep raw for aggregation
            '_times_no_timeout': times_no_timeout
        })

    return pd.DataFrame(results)

df = scan_experiments()
print(f"Found {len(df)} experiment configurations")
# Display raw data with percentages (hide internal columns)
df['timeout_pct'] = (df['timeouts'] / df['total_keys'] * 100).round(2)
df = df.sort_values(['dataset', 'class', 'setup'])

display_cols = ['dataset', 'class', 'setup', 'timeouts', 'total_keys', 'timeout_pct', 'avg_time', 'avg_time_no_timeout']
display(df[display_cols])

Found 156 experiment configurations


Unnamed: 0,dataset,class,setup,timeouts,total_keys,timeout_pct,avg_time,avg_time_no_timeout
150,ann-thyroid,1.0,ablation_R0_NR0_GP1_BP1,0,16,0.00,69.81,69.81
151,ann-thyroid,1.0,ablation_R0_NR1_GP1_BP1,1,19,5.26,113.02,98.77
152,ann-thyroid,1.0,ablation_R1_NR0_GP1_BP1,0,19,0.00,84.73,84.73
153,ann-thyroid,1.0,ablation_R1_NR1_GP0_BP1,1,19,5.26,62.74,44.64
154,ann-thyroid,1.0,ablation_R1_NR1_GP1_BP0,0,19,0.00,94.42,94.42
...,...,...,...,...,...,...,...,...
0,karhunen,9.0,ablation_R0_NR1_GP1_BP1,24,24,100.00,575.74,0.00
1,karhunen,9.0,ablation_R1_NR0_GP1_BP1,24,24,100.00,818.79,0.00
2,karhunen,9.0,ablation_R1_NR1_GP0_BP1,20,24,83.33,527.59,229.07
3,karhunen,9.0,ablation_R1_NR1_GP1_BP0,24,24,100.00,672.70,0.00


In [2]:

# Aggregated by dataset and setup (recompute averages from raw times)
def aggregate_group(group):
    all_times = [t for times in group['_times_all'] for t in times]
    no_timeout_times = [t for times in group['_times_no_timeout'] for t in times]
    return pd.Series({
        'timeouts': group['timeouts'].sum(),
        'total_keys': group['total_keys'].sum(),
        'avg_time': round(sum(all_times) / len(all_times), 2) if all_times else 0,
        'avg_time_no_timeout': round(sum(no_timeout_times) / len(no_timeout_times), 2) if no_timeout_times else 0
    })

agg = df.groupby(['dataset', 'setup']).apply(aggregate_group).reset_index()
agg['timeout_pct'] = (agg['timeouts'] / agg['total_keys'] * 100).round(2)
agg = agg[['dataset', 'setup', 'timeouts', 'total_keys', 'timeout_pct', 'avg_time', 'avg_time_no_timeout']]
print("\nAggregated by dataset, setup:")
display(agg)


Aggregated by dataset, setup:


Unnamed: 0,dataset,setup,timeouts,total_keys,timeout_pct,avg_time,avg_time_no_timeout
0,ann-thyroid,ablation_R0_NR0_GP1_BP1,0.0,16.0,0.0,69.81,69.81
1,ann-thyroid,ablation_R0_NR1_GP1_BP1,32.0,83.0,38.55,216.51,62.25
2,ann-thyroid,ablation_R1_NR0_GP1_BP1,9.0,83.0,10.84,97.35,58.31
3,ann-thyroid,ablation_R1_NR1_GP0_BP1,11.0,83.0,13.25,94.27,46.29
4,ann-thyroid,ablation_R1_NR1_GP1_BP0,8.0,83.0,9.64,103.39,65.36
5,ann-thyroid,all_features,9.0,83.0,10.84,107.94,65.11
6,appendicitis,ablation_R0_NR1_GP1_BP1,0.0,40.0,0.0,7.11,7.11
7,appendicitis,ablation_R1_NR0_GP1_BP1,0.0,40.0,0.0,6.13,6.13
8,appendicitis,ablation_R1_NR1_GP0_BP1,0.0,40.0,0.0,5.2,5.2
9,appendicitis,ablation_R1_NR1_GP1_BP0,0.0,40.0,0.0,6.09,6.09


In [3]:
import plotly.graph_objects as go
import numpy as np

# Map setup names to deactivated features
setup_labels = {
    'ablation_R0_NR1_GP1_BP1': 'R',
    'ablation_R1_NR0_GP1_BP1': 'NR',
    'ablation_R1_NR1_GP0_BP1': 'GP',
    'ablation_R1_NR1_GP1_BP0': 'BP',
    'all_features': 'None'
}
# Calculate no timeout percentage
agg['no_timeout_pct'] = 100 - agg['timeout_pct']

# Pivot tables
pivot_time = agg.pivot(index='dataset', columns='setup', values='avg_time_no_timeout')
pivot_pct = agg.pivot(index='dataset', columns='setup', values='no_timeout_pct')

# Rename columns to show deactivated features
pivot_time.columns = [setup_labels.get(col, col) for col in pivot_time.columns]
pivot_pct.columns = [setup_labels.get(col, col) for col in pivot_pct.columns]

# Create combined text: "time (pct%)"
combined_text = []
for i in range(len(pivot_time.index)):
    row = []
    for j in range(len(pivot_time.columns)):
        time_val = pivot_time.values[i, j]
        pct_val = pivot_pct.values[i, j]
        if np.isnan(time_val):
            row.append('')
        else:
            row.append(f'{time_val:.1f} ({pct_val:.0f}%)')
    combined_text.append(row)

fig = go.Figure(data=go.Heatmap(
    z=pivot_time.values,
    x=pivot_time.columns,
    y=pivot_time.index,
    text=combined_text,
    texttemplate='%{text}',
    textfont={'size': 16},
    colorscale='Blues',
    colorbar=dict(title='Avg Time (s)')
))

fig.update_layout(
    height=500,
    width=900,
    xaxis_title="Deactivated Feature",
    yaxis_title="Dataset name",
)
fig.write_image("ablation_heatmap_combined.pdf")
fig.write_image("ablation_heatmap_combined.png")
fig.show()