In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

In [None]:
# Root data directory; each dataset is expected to live in its own sub-folder.
folder_path = '../data'
# Keep only sub-directories (plain files are skipped). The names are sorted
# because os.listdir returns entries in arbitrary order, which would otherwise
# make the dataset ordering differ between machines and runs.
datasets = sorted(
    name for name in os.listdir(folder_path)
    if os.path.isdir(os.path.join(folder_path, name))
)

In [None]:
def plot(dataset, figsize, data_dir='../data', loss_dir='../loss_csvs', out_dir='pngs'):
    """Plot mean ± SD of the log10 loss per method for one dataset and save it as a PNG.

    Parameters
    ----------
    dataset : str
        Dataset name. Used to locate ``{data_dir}/{dataset}/targets.csv``,
        ``{data_dir}/{dataset}/features.csv`` and ``{loss_dir}/{dataset}.csv``,
        and to name the output file.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    data_dir : str, optional
        Directory containing one sub-folder per dataset.
    loss_dir : str, optional
        Directory containing the per-dataset loss CSV (needs ``method`` and
        ``loss`` columns).
    out_dir : str, optional
        Directory the PNG is written to; created if missing.

    Returns
    -------
    None
        The figure is saved to ``{out_dir}/{dataset}_plot.png`` and then shown.
    """
    # Display order of the methods along the y-axis (first entry at the bottom).
    desired_order = ['constant', 'knn', 'linear', 'mmit', 'mmif', 'aft_xgb', 'mlp']

    # Row / column counts are only used for the figure title.
    datasize = pd.read_csv(f'{data_dir}/{dataset}/targets.csv').shape[0]
    n_features = pd.read_csv(f'{data_dir}/{dataset}/features.csv').shape[1]

    losses = pd.read_csv(f'{loss_dir}/{dataset}.csv')

    # Map the raw method identifiers onto the names used in the figure.
    losses['method'] = losses['method'].replace({'rf': 'mmif', 'aft_xgboost_original': 'aft_xgb'})

    # Keep only the methods of interest and drop incomplete rows. The explicit
    # .copy() makes the column assignment below operate on an independent frame
    # instead of a slice (avoids SettingWithCopyWarning).
    losses = losses[losses['method'].isin(desired_order)].dropna().copy()

    # Work on a log10 scale; the small offset keeps a zero loss finite.
    losses['loss'] = np.log10(losses['loss'] + 1e-10)

    # Per-method mean and sample SD, aligned to the display order. Methods with
    # no rows come out as NaN and are simply not drawn.
    summary = (
        losses.groupby('method')['loss']
        .agg(['mean', 'std'])
        .reindex(desired_order)
    )

    fig, ax = plt.subplots(figsize=figsize)

    y_pos = np.arange(len(desired_order))
    ax.errorbar(
        summary['mean'].to_numpy(), y_pos, xerr=summary['std'].to_numpy(),
        fmt='o', color='black', ecolor='gray', capsize=5, label='Mean ± SD',
    )
    ax.set_yticks(y_pos)  # one tick per method index
    ax.set_yticklabels(desired_order)
    ax.set_xlabel('log_test_squared_hinge_loss')
    ax.set_title(f"{dataset} ({datasize} instances -- {n_features} features)")
    ax.set_ylim(-0.5, len(desired_order) - 0.5)
    ax.grid(True)

    # Adjust the layout BEFORE saving so the file on disk reflects it.
    fig.tight_layout()

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{dataset}_plot.png', dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate open figures.
    plt.close(fig)

In [None]:
# Render the figure for each simulated benchmark, all at the same compact size.
figure_size = (4, 3)
for dataset in ('simulated.linear', 'simulated.abs', 'simulated.sin'):
    plot(dataset, figure_size)