In [54]:
import pandas as pd
import matplotlib.pyplot as plt

In [55]:

def consolidate_model_performance(xlsx_path, algorithm_type):
    # Load the data from the Excel file
    df = pd.read_excel(xlsx_path)

    # Clean up column names by stripping any leading/trailing spaces
    df.columns = df.columns.str.strip()

    # Filter the DataFrame based on the algorithm type
    filtered_df = df[df['Algorithm'] == algorithm_type]

    # Initialize a list to store DataFrame slices before concatenation
    results_list = []

    # Group by model and find the maximum improvement and the corresponding threshold for each metric
    for model, group in filtered_df.groupby('Model'):
        for metric in ['Precision Improvement (%)', 'Recall Improvement (%)', 'F1 Improvement (%)']:
            max_value = group[metric].max()
            # Find the threshold at which the maximum improvement occurs
            max_threshold = group.loc[group[metric].idxmax(), 'Threshold'] if max_value > 0 else None
            # Create a dictionary of the results
            result_dict = {
                'Model': model,
                'Metric': metric,
                'Max Improvement': max_value,
                'Threshold at Max': max_threshold
            }
            # Convert the dictionary to a DataFrame and append it to the list
            results_list.append(pd.DataFrame([result_dict]))

    # Concatenate all DataFrame slices in the list into a single DataFrame
    results_df = pd.concat(results_list, ignore_index=True)

    # Reshape the DataFrame for better readability
    return results_df.pivot(index='Model', columns='Metric', values=['Max Improvement', 'Threshold at Max'])


In [56]:
def plot_improvements(xlsx_path, algorithm_type, commodity_type):
    # Load the data
    df = pd.read_excel(xlsx_path)

    # Clean up column names by stripping any leading/trailing spaces
    df.columns = df.columns.str.strip()

    # Filter based on the algorithm type
    df = df[df['Algorithm'] == algorithm_type]

    # Group data by 'Model'
    grouped = df.groupby('Model')

    # Plotting separate charts for Precision and Recall
    for name, group in grouped:
        # Remove NaN values and check if data is valid
        prec_data = group['Precision Improvement (%)'].dropna()
        rec_data = group['Recall Improvement (%)'].dropna()

        if prec_data.empty or rec_data.empty or prec_data.nunique() == 1 or rec_data.nunique() == 1:
            continue  # Skip this plot if data is insufficient or non-variable

        fig, ax1 = plt.subplots(figsize=(10, 5))

        color = 'tab:red'
        ax1.set_xlabel('Threshold')
        ax1.set_ylabel('Precision Improvement (%)', color=color)
        ax1.plot(group['Threshold'], prec_data, label='Precision Improvement (%)', marker='o', color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        
        ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
        color = 'tab:blue'
        ax2.set_ylabel('Recall Improvement (%)', color=color)  # we already handled the x-label with ax1
        ax2.plot(group['Threshold'], rec_data, label='Recall Improvement (%)', marker='o', color=color)
        ax2.tick_params(axis='y', labelcolor=color)

        # Title and grid
        plt.title(f'[{commodity_type}] Precision and Recall Improvement vs. Threshold for {name} using {algorithm_type}')
        ax1.grid(True)

        # Show plot
        fig.tight_layout()  
        plt.show()



In [57]:
TYPE = 'threshold'
ALGORITHM = ['correction', 'detection_correction']
COMMODITY = 'nickel_shift_20'

In [58]:
results_df = consolidate_model_performance(f'out/{TYPE}/{COMMODITY}_results.xlsx', 'detection_correction')
results_df

  results_df = pd.concat(results_list, ignore_index=True)


Unnamed: 0_level_0,Max Improvement,Max Improvement,Max Improvement,Threshold at Max,Threshold at Max,Threshold at Max
Metric,F1 Improvement (%),Precision Improvement (%),Recall Improvement (%),F1 Improvement (%),Precision Improvement (%),Recall Improvement (%)
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
CNN_Attention_128_filters_5_kernels_predictions,0.043386,0.000842,0.074074,0.4,0.4,0.4
CNN_Attention_256_filters_7_kernels_predictions,0.0,0.0,0.0,,,
CNN_Attention_64_filters_3_kernels_predictions,0.074271,0.026913,0.413793,0.1,0.1,0.45
CNN_Attention_64_filters_5_kernels_predictions,0.0,0.0,0.0,,,


In [59]:
# for COMMODITY in [
#   'cobalt_shift_20', 'copper_shift_20', 'magnesium_shift_20', 'nickel_shift_20',
# ]:
#     for ALGO in ['correction', 'detection_correction']:
#         path = f'out/{TYPE}/{COMMODITY}_results.xlsx'
#         plot_improvements(path, ALGO, COMMODITY)