In [None]:
# !pip install pandas seaborn matplotlib
# Also install the "Palatino Linotype" font (preferred serif font for the figures).

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import numpy as np

In [None]:
# ==========================================
# 1. Configuration & Style Setup
# ==========================================
def setup_plotting_style():
    """
    Apply the shared figure style used by every plot in this notebook.

    First activates the seaborn "ticks" theme in the "paper" context
    (font scale 1.4), then overrides selected matplotlib rcParams:
    serif fonts, text sizes, line/marker sizes, output resolution,
    a light dashed grid, and a framed, slightly transparent legend.
    """
    sns.set_theme(style="ticks", context="paper", font_scale=1.4)

    # Serif family; matplotlib walks this list and uses the first font it finds.
    font_overrides = {
        'font.family': 'serif',
        'font.serif': ['Palatino Linotype', 'Palatino', 'URW Palladio L', 'serif'],
    }
    # Text sizes for labels, titles, ticks and legend entries.
    text_size_overrides = {
        'axes.labelsize': 14,
        'axes.titlesize': 20,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 11,
    }
    # Default line width and marker size.
    line_overrides = {
        'lines.linewidth': 2.0,
        'lines.markersize': 8,
    }
    # Same resolution for on-screen and saved figures.
    resolution_overrides = {
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    # Faint dashed grid behind the data.
    grid_overrides = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
    }
    # Legend drawn inside a mostly opaque frame.
    legend_overrides = {
        'legend.frameon': True,
        'legend.framealpha': 0.9,
    }

    combined_overrides = {
        **font_overrides,
        **text_size_overrides,
        **line_overrides,
        **resolution_overrides,
        **grid_overrides,
        **legend_overrides,
    }
    plt.rcParams.update(combined_overrides)


In [None]:
# ==========================================
# 2. Data Loading & Preprocessing
# ==========================================
def load_and_prep_data(filename="H1_results.csv"):
    """
    Load the results CSV and relabel it for plotting.

    Side effect: makes sure the output directory "fig" exists.

    The file is looked up first as ``data/<filename>`` and then as
    ``<filename>`` relative to the current working directory. If neither
    exists, a warning is printed and an empty DataFrame is returned so the
    caller can decide how to proceed.

    Relabelling applied to the loaded frame (each step is a no-op when the
    column is absent):
      * column 'ChangeParam' -> 'Change Magnitude'
      * column 'K_Star'      -> 'Change Point'
      * 'ChangeType' values 'inflation' / 'ar1' / 'spike' -> descriptive titles
      * 'WeightFunction' values 'gamma_*' / 'econ' -> LaTeX legend labels

    Parameters
    ----------
    filename : str
        Name (or path) of the CSV file to read.

    Returns
    -------
    (pandas.DataFrame, str)
        The prepared frame (possibly empty) and the figure directory name.
    """
    fig_path = "fig"
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(fig_path, exist_ok=True)

    # Prefer the copy inside the 'data' folder; fall back to the bare name.
    candidate_paths = (os.path.join("data", filename), filename)
    file_path = next((p for p in candidate_paths if os.path.exists(p)), None)

    if file_path is None:
        print(f"Warning: {filename} not found. Ensure it is in the current directory or 'data' folder.")
        # Hand back an empty frame instead of raising, so the rest of the
        # notebook can detect the missing data via `df.empty`.
        return pd.DataFrame(), fig_path

    df = pd.read_csv(file_path)

    # rename() ignores labels that are not present, so no membership checks
    # are needed for columns missing from a given results file.
    df = df.rename(columns={
        'ChangeParam': 'Change Magnitude',
        'K_Star': 'Change Point',
    })

    change_type_map = {
        'inflation': 'Homogeneous Variance Inflation',
        'ar1': 'Correlation Structure Change',
        'spike': 'Heterogeneous Variance Inflation'
    }
    if 'ChangeType' in df.columns:
        # replace() (not map()) so unrecognised change types are kept as-is.
        df['ChangeType'] = df['ChangeType'].replace(change_type_map)

    # Legend Mapping: raw weight-function codes -> math-text legend labels
    gamma_legend_map = {
        "gamma_0.0": r"$\rho_{1,0}$",
        "gamma_0.25": r"$\rho_{1,0.25}$",
        "gamma_0.45": r"$\rho_{1,0.45}$",
        "econ": r"$\rho_{2}$"
    }
    if 'WeightFunction' in df.columns:
        df['WeightFunction'] = df['WeightFunction'].replace(gamma_legend_map)

    return df, fig_path

In [None]:
# ==========================================
# 3. Generic Plotting Engine
# ==========================================
# %%
def format_tick_label(val):
    """
    Helper to format a numeric tick position as a short label.

    - If it's effectively an integer (e.g. 350.0), return '350'.
    - If |val| >= 1, return fixed-point text with at most two decimals and
      trailing zeros removed (e.g. 1.25 -> '1.25', 387.5 -> '387.5').
    - Otherwise return two significant digits (e.g. 0.25 -> '0.25',
      0.0025 -> '0.0025').

    The '.2g' format is deliberately NOT used for |val| >= 1: it keeps only
    two significant digits in total, so 1.25 would be shown as '1.2', 12.5 as
    '12', and 387.5 as '3.9e+02' -- labels that no longer match the tick they
    are attached to.
    """
    nearest_int = round(val)
    if abs(val - nearest_int) < 1e-8:
        return str(int(nearest_int))
    if abs(val) >= 1:
        # '.2f' always contains a decimal point, so stripping zeros and then
        # a dangling '.' can never eat digits of the integer part.
        return f'{val:.2f}'.rstrip('0').rstrip('.')
    return f'{val:.2g}'
def plot_metrics_grid(
    df, 
    x_col, 
    hue_col, 
    style_col, 
    fig_path, 
    hue_order=None, 
    palette=None, 
    markers=None,
    dashes=None,
    manual_xticks=False
):
    """
    Draw a 1x3 row of 'EDD' line plots, one panel per value of 'ChangeType'.

    Panels follow a preferred change-type order when those labels are present
    in the data; otherwise the labels found in the data are used. At most
    three change types are drawn; unused panels are switched off. The x-axis
    name is written once, centred below the whole figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'ChangeType', 'EDD', `x_col`, `hue_col` and `style_col`.
    x_col : str
        Column plotted on the x-axis (also used as the shared x label).
    hue_col, style_col : str
        Columns mapped to line colour and to marker/dash style.
    fig_path : str
        Not used by this function; kept so existing call sites keep working.
    hue_order : list, optional
        Level order forwarded to seaborn as both `hue_order` and `style_order`.
    palette : optional
        Forwarded to `sns.lineplot`.
    markers, dashes : optional
        Forwarded to `sns.lineplot`. ``None`` (or an empty container) means
        "let seaborn choose" (``True``). An explicit ``False`` is forwarded
        unchanged, which seaborn interprets as no markers / solid lines.
    manual_xticks : bool
        When true, each panel gets 5 evenly spaced ticks spanning the range of
        its x values (or a single tick if there is only one x value).

    Returns
    -------
    module or None
        The `matplotlib.pyplot` module (so the caller can `savefig` the
        current figure), or ``None`` when `df` is empty and nothing is drawn.
    """
    if df.empty:
        print("Data is empty, skipping plot.")
        return

    metric_key, metric_label = 'EDD', 'Avg EDD'

    # Identify change types present
    preferred_order = ['Homogeneous Variance Inflation', 'Correlation Structure Change', 'Heterogeneous Variance Inflation']
    available_types = df['ChangeType'].unique()
    change_types = [ct for ct in preferred_order if ct in available_types]
    if not change_types:
        change_types = list(available_types)

    n_cols = 3
    n_rows = 1

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5))
    # Always work with a flat 1-D array of axes, whatever shape subplots() returned.
    axes = np.atleast_1d(axes).ravel()

    # Only "not specified" falls back to seaborn's automatic styles. The
    # previous truthiness test also replaced an explicit False with True, so
    # markers/dashes could never be switched off by the caller.
    marker_spec = markers if (markers is False or markers) else True
    dash_spec = dashes if (dashes is False or dashes) else True

    # Only the first n_cols change types fit on the grid.
    for j, c_type in enumerate(change_types[:n_cols]):
        ax = axes[j]
        subset = df[df['ChangeType'] == c_type]

        if subset.empty:
            ax.axis('off')
            continue

        # --- Logic for Necessary X-Ticks ---
        current_xticks = None
        current_xticklabels = None
        if manual_xticks:
            xvals = np.sort(subset[x_col].unique())
            if len(xvals) > 1:
                current_xticks = np.linspace(xvals.min(), xvals.max(), 5)
            else:
                current_xticks = xvals
            current_xticklabels = [format_tick_label(v) for v in current_xticks]

        sns.lineplot(
            data=subset,
            x=x_col,
            y=metric_key,
            hue=hue_col,
            style=style_col,
            hue_order=hue_order,
            style_order=hue_order,
            palette=palette,
            markers=marker_spec,
            dashes=dash_spec,
            ax=ax,
            errorbar=None,
            linewidth=2.5,
            markersize=9
        )

        ax.set_title(f"{c_type}", fontsize=16)
        # Per-panel x labels are blanked; a single shared label is added below.
        ax.set_xlabel("")
        # Only the left-most panel carries the y label.
        ax.set_ylabel(metric_label if j == 0 else "")

        # --- Apply the Sparse Ticks ---
        if current_xticks is not None:
            # Use set_xticks with labels parameter to avoid warnings
            ax.set_xticks(current_xticks, labels=current_xticklabels)

        # Legend handling: re-create the legend so its title is the hue column.
        if ax.get_legend() is not None:
            ax.legend(title=hue_col)

        sns.despine(ax=ax)

    # Hide panels that have no change type assigned.
    for j in range(len(change_types), n_cols):
        axes[j].axis('off')

    # Reserve the bottom 8% of the figure for the shared x label.
    fig.tight_layout(rect=[0, 0.08, 1, 1])

    fig.text(0.5, 0.02, x_col, ha='center', va='bottom', fontsize=16)

    return plt

In [None]:
# ==========================================
# 4. Main Execution Flow
# ==========================================

def _save_current_figure(fig_path, file_name):
    """Write the active matplotlib figure to ``fig_path/file_name`` and report where it went."""
    save_full_path = os.path.join(fig_path, file_name)
    plt.savefig(save_full_path, bbox_inches='tight')
    print(f"Saved: {save_full_path}")


if __name__ == "__main__":

    setup_plotting_style()
    df, fig_path = load_and_prep_data()
    # rename() ignores labels that are absent, so this is a no-op when the
    # file has no 'FunctionType' column.
    df = df.rename(columns={'FunctionType': 'TestFunction'})

    # --- Configuration for Figure 1 & 3 ---
    # Legend labels as they appear in the 'WeightFunction' column after loading.
    k1 = r"$\rho_{1,0}$"
    k2 = r"$\rho_{1,0.25}$"
    k3 = r"$\rho_{1,0.45}$"
    k4 = r"$\rho_{2}$"

    gamma_keys = [k1, k2, k3, k4]

    std_palette = dict(zip(gamma_keys, ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']))
    std_markers = dict(zip(gamma_keys, ['o', 's', '^', 'D']))

    # Custom Dashes for Gamma Groups
    # "" = solid, (2,2) = dash, etc.
    std_dashes = {
        k1: "",
        k2: (2, 2),
        k3: (5, 2),
        k4: (1, 1)
    }

    if df.empty:
        # Without data nothing is drawn. Previously each savefig still ran and
        # wrote blank PDFs, and the final catplot failed on the empty frame.
        print("No data loaded; skipping figure generation.")
    else:
        # 1. Figure 1: EDD vs Change Magnitude
        # manual_xticks=True -> Shows only ~5 necessary ticks
        plot_metrics_grid(
            df=df,
            x_col='Change Magnitude',
            hue_col='WeightFunction',
            style_col='WeightFunction',
            fig_path=fig_path,
            hue_order=gamma_keys,
            palette=std_palette,
            markers=std_markers,
            dashes=std_dashes,
            manual_xticks=True
        )
        _save_current_figure(fig_path, "simu1_weight_magnitude.pdf")

        # 2. Figure 2: Comparison of TestFunction
        # manual_xticks=True -> Also sparsifies ticks for this plot (Same x-axis as Fig 1)
        plot_metrics_grid(
            df=df,
            x_col='Change Magnitude',
            hue_col='TestFunction',
            style_col='TestFunction',
            fig_path=fig_path,
            dashes=True, # Auto-assign line styles for different functions
            manual_xticks=True
        )
        _save_current_figure(fig_path, "simu2_test_magnitude.pdf")

        # 3. Figure 3: EDD vs Change Point Location (K_Star)
        plot_metrics_grid(
            df=df,
            x_col='Change Point',
            hue_col='WeightFunction',
            style_col='WeightFunction',
            fig_path=fig_path,
            hue_order=gamma_keys,
            palette=std_palette,
            markers=std_markers,
            dashes=std_dashes,
            manual_xticks=True,
        )
        # Override the automatically spaced ticks on every panel with three
        # fixed change-point locations.
        for ax in plt.gcf().axes:
            ax.set_xticks([350, 450, 550])
            ax.set_xticklabels(['350', '450', '550'])
        _save_current_figure(fig_path, "simu3_weight_kstar.pdf")

        # 4. Figure 4: grouped bars of EDD per distribution, one facet per change type
        g2 = sns.catplot(
            data=df,
            x="Distribution",
            y="EDD",
            hue="WeightFunction",
            col="ChangeType",
            kind="bar",
            height=4,
            aspect=1.2,
            sharey=False,
            errorbar=None,
            palette="Set2"
        )
        g2.set_axis_labels("", "Avg EDD")
        g2.set_titles("{col_name}")
        # Make room for, then add, a single shared x label under all facets.
        g2.figure.subplots_adjust(bottom=0.15)
        g2.figure.text(0.5, 0.02, "Distribution", ha='center', va='bottom', fontsize=16)
        _save_current_figure(fig_path, "simu4_distribution_robustness.pdf")

        plt.show()