In [1]:
import os
import pandas as pd

def aggregate_pred_dataframe(files):
    """Merge the ``pred`` column of several prediction CSVs into one frame.

    The first file supplies every column of the result, with its ``pred``
    column renamed to ``pred_0``.  Each later file contributes only its
    ``pred`` column, stored as ``pred_1``, ``pred_2``, ... in input order.

    Args:
        files: Paths of CSV files readable by ``pd.read_csv``.  Every file
            after the first must contain a ``pred`` column.

    Returns:
        The combined DataFrame after ``reset_index()``, i.e. the previous
        index is kept as an ordinary column.
    """
    frames = list(map(pd.read_csv, files))
    combined = frames[0].rename(columns={'pred': 'pred_0'})
    # Files after the first are numbered from 1, matching their list position.
    for position, frame in enumerate(frames[1:], start=1):
        combined[f'pred_{position}'] = frame['pred']
    return combined.reset_index()

def get_preds_with_prefix(prefix, seed_start=0, seed_end=20, preds_dir='preds'):
    """Collect one pair of prediction file names per seed.

    For every seed in ``range(seed_start, seed_end)`` the ``.csv`` files in
    ``preds_dir`` whose names start with ``f'{prefix}_{seed}'`` (and whose next
    character is not a digit) are gathered in sorted order.  The last two of
    them form the pair for that seed.

    Args:
        prefix: File-name prefix placed before ``_<seed>``.
        seed_start: First seed to look for (inclusive).
        seed_end: Seed at which to stop (exclusive).
        preds_dir: Directory that is listed for prediction files.

    Returns:
        A list of ``(last, second_to_last)`` file-name tuples, one per seed
        that has at least two matching files.  ``aggregate_preds`` reads
        element 0 as the validation file and element 1 as the test file.

    Raises:
        OSError: If ``preds_dir`` cannot be listed.
    """
    csv_files = sorted(
        name for name in os.listdir(preds_dir) if name.endswith('.csv')
    )
    preds = []
    for seed in range(seed_start, seed_end):
        stem = f'{prefix}_{seed}'
        # The character right after the seed must not be a digit; a plain
        # startswith() lets seed 1 also match the files of seeds 10-19.
        matches = [
            name for name in csv_files
            if name.startswith(stem)
            and not name[len(stem):len(stem) + 1].isdigit()
        ]
        # A pair needs two files; seeds with fewer are skipped rather than
        # failing with an IndexError on matches[-2].
        if len(matches) >= 2:
            preds.append((matches[-1], matches[-2]))
    return preds

def aggregate_preds(preds):
    """Build the combined validation and test frames for a list of file pairs.

    Args:
        preds: Sequence of pairs of file names located under ``preds/``.
            Element 0 of each pair goes into the validation frame and
            element 1 into the test frame.

    Returns:
        A ``(df_valid, df_test)`` tuple, each produced by
        ``aggregate_pred_dataframe``.
    """
    valid_paths = []
    test_paths = []
    for pair in preds:
        valid_paths.append(f'preds/{pair[0]}')
        test_paths.append(f'preds/{pair[1]}')

    return aggregate_pred_dataframe(valid_paths), aggregate_pred_dataframe(test_paths)

In [2]:
%matplotlib inline
from sklearn.metrics import confusion_matrix, recall_score, accuracy_score, precision_score, matthews_corrcoef
import matplotlib.pyplot as plt
import numpy as np

def round_dict(d, n):
    """Return a copy of ``d`` with every ``float`` value rounded to ``n`` digits.

    Values that are not instances of ``float`` are copied over unchanged, and
    the key order of ``d`` is kept.
    """
    rounded = {}
    for key, value in d.items():
        if isinstance(value, float):
            value = round(value, n)
        rounded[key] = value
    return rounded

def generate_mean_ensemble_metrics(df, threshold=0):
    """Score a mean-of-models ensemble against ``df['target']``.

    All columns whose name starts with ``pred_`` are averaged row-wise; a row
    is predicted positive (1) when that mean is strictly greater than
    ``threshold`` and negative (0) otherwise.

    Args:
        df: DataFrame with a ``target`` column and one or more ``pred_*``
            columns.  The metrics below assume ``target`` holds 0/1 labels.
        threshold: Cut-off applied to the mean prediction.

    Returns:
        A dict with the keys ``sensitivity``, ``specificity``, ``accuracy``,
        ``precision`` and ``mcc``.  ``specificity`` is NaN when ``target``
        contains no negative samples.
    """
    pred_columns = [column for column in df.columns if column.startswith('pred_')]
    mean_pred = df[pred_columns].mean(axis=1)
    final_prediction = (mean_pred > threshold).astype(int)
    target = df['target']

    # labels=[0, 1] pins the matrix to 2x2.  Without it, data where only one
    # class occurs yields a 1x1 matrix and the four-way unpacking fails.
    tn, fp, fn, tp = confusion_matrix(target, final_prediction, labels=[0, 1]).ravel()
    negatives = tn + fp
    # Specificity is undefined without negative samples; avoid a 0/0 division.
    specificity = tn / negatives if negatives else float('nan')

    return {
        # zero_division=0 returns the same 0.0 as the default but without a
        # warning for thresholds at which nothing is predicted positive.
        "sensitivity": recall_score(target, final_prediction, zero_division=0),
        "specificity": specificity,
        "accuracy": accuracy_score(target, final_prediction),
        "precision": precision_score(target, final_prediction, zero_division=0),
        "mcc": matthews_corrcoef(target, final_prediction),
    }
    
def draw_mean_ensemble_thrshold_chart(df_valid, df_test, start=-3, end=1, plot=True):
    """Sweep the decision threshold and pick the one with the best validation MCC.

    Thresholds come from ``np.arange(start, end, 0.1)``.  For each of them the
    mean ensemble is scored on ``df_valid`` and on ``df_test`` with
    ``generate_mean_ensemble_metrics``.  The threshold with the highest
    validation MCC is selected; the test set plays no part in the selection.

    Args:
        df_valid: Validation frame used to choose the threshold.
        df_test: Test frame evaluated at the same thresholds.
        start: First threshold of the sweep.
        end: Stop value handed to ``np.arange``.
        plot: When true, draw both MCC curves and mark the chosen threshold.

    Returns:
        A dict holding ``best_threshold``, ``valid_mcc`` and every test metric
        measured at that threshold.
    """
    thresholds = np.arange(start, end, 0.1)  # Adjust the step size as necessary

    # One score per candidate threshold, in the same order as `thresholds`.
    valid_mccs = [
        generate_mean_ensemble_metrics(df_valid, cutoff)['mcc']
        for cutoff in thresholds
    ]
    test_metrics = [
        generate_mean_ensemble_metrics(df_test, cutoff)
        for cutoff in thresholds
    ]

    # argmax returns the first index on ties, i.e. the lowest such threshold.
    best_idx = np.argmax(valid_mccs)
    best_threshold = thresholds[best_idx]
    best_valid_mcc = valid_mccs[best_idx]
    best_test = test_metrics[best_idx]

    if plot:
        label = f'Best Threshold: {best_threshold:.1f}, Valid MCC: {best_valid_mcc:.3f}, Test MCC: {best_test["mcc"]:.3f}'
        test_mccs = [entry['mcc'] for entry in test_metrics]
        plt.figure(figsize=(10, 6))
        plt.plot(thresholds, valid_mccs, label='Valid MCC', color='blue')
        plt.plot(thresholds, test_mccs, label='Test MCC', color='green')
        plt.axvline(x=best_threshold, color='red', linestyle='--', label=label)
        plt.xlabel('Threshold')
        plt.ylabel('MCC Value')
        plt.title('MCC vs. Threshold')
        plt.legend()
        plt.grid(True)
        plt.show()

    summary = {'best_threshold': best_threshold, 'valid_mcc': best_valid_mcc}
    summary.update(best_test)
    return summary


def random_small_ensembles(preds, n, trial, seed=None):
    """Evaluate several randomly drawn sub-ensembles of ``n`` models each.

    Each trial samples ``n`` entries of ``preds`` without replacement, builds
    the validation/test frames with ``aggregate_preds`` and records the result
    of ``draw_mean_ensemble_thrshold_chart`` (sweep from -3 to 1, no plot).

    Args:
        preds: List of file-name pairs as accepted by ``aggregate_preds``.
        n: Number of models per sub-ensemble; must not exceed ``len(preds)``.
        trial: Number of sub-ensembles to draw.
        seed: Optional seed for reproducible sampling.  When ``None`` the
            module-level ``random`` generator is used.

    Returns:
        A DataFrame with one row per trial and one column per reported metric.

    Raises:
        ValueError: If ``n`` is larger than ``len(preds)`` (from ``sample``).
    """
    import random

    # Without a seed, keep drawing from the global generator so an outer
    # random.seed(...) call still controls the result.
    rng = random if seed is None else random.Random(seed)

    records = []
    for _ in range(trial):
        sample_preds = rng.sample(preds, n)
        df_valid, df_test = aggregate_preds(sample_preds)
        records.append(
            draw_mean_ensemble_thrshold_chart(df_valid, df_test, start=-3, end=1, plot=False)
        )

    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame(records)


def summarize_prefix(prefix, n, trials=10, seed_start=0, seed_end=20):
    """Print full-ensemble metrics for ``prefix`` and summarize random sub-ensembles.

    The prediction files for the seeds in ``range(seed_start, seed_end)`` are
    looked up with ``get_preds_with_prefix``.  Metrics of the ensemble over
    all found models are printed (rounded to 4 digits), then ``trials`` random
    sub-ensembles of ``n`` models are evaluated.

    Args:
        prefix: File-name prefix identifying the run to summarize.
        n: Number of models in each random sub-ensemble.
        trials: Number of random sub-ensembles to evaluate.
        seed_start: First seed to include (inclusive).
        seed_end: Seed at which to stop (exclusive).

    Returns:
        A DataFrame with one row per metric and the columns ``mean``, ``std``
        and ``max`` computed across the random sub-ensembles.
    """
    preds = get_preds_with_prefix(prefix, seed_start=seed_start, seed_end=seed_end)
    df_valid, df_test = aggregate_preds(preds)

    print(f'Ensemble of all {len(preds)} models:')
    full_metrics = draw_mean_ensemble_thrshold_chart(
        df_valid, df_test, start=-3, end=1, plot=False
    )
    print(round_dict(full_metrics, 4))

    print(f'Ensemble of {n} random models:')
    trial_metrics = random_small_ensembles(preds, n, trials)
    # Transpose so that each metric becomes a row of the summary table.
    return trial_metrics.aggregate(['mean', 'std', 'max']).T


In [3]:
# Summarize the 'v001' prediction files with random sub-ensembles of 10 models.
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt;
# re-run it to completion before sharing the notebook.
summarize_prefix('v001', n=10)

Ensemble of all 20 models:


KeyboardInterrupt: 

In [None]:
# Summarize the 'v002' prediction files with random sub-ensembles of 10 models.
summarize_prefix('v002', n=10)

Ensemble of all 20 models:
{'best_threshold': -1.0, 'valid_mcc': 0.6344, 'sensitivity': 0.5502, 'specificity': 0.9902, 'accuracy': 0.9674, 'precision': 0.7533, 'mcc': 0.6277}
Ensemble of 10 random models:


Unnamed: 0,mean,std,max
best_threshold,-1.35,0.334166,-0.9
valid_mcc,0.628503,0.004834,0.639985
sensitivity,0.566507,0.01591,0.591707
specificity,0.988144,0.001933,0.990853
accuracy,0.966306,0.001052,0.967702
precision,0.724548,0.027103,0.764574
mcc,0.623343,0.005105,0.62997


In [None]:
# Summarize the 'v003' prediction files with random sub-ensembles of 10 models.
summarize_prefix('v003', n=10)

Ensemble of all 20 models:
{'best_threshold': -1.0, 'valid_mcc': 0.6311, 'sensitivity': 0.5646, 'specificity': 0.9894, 'accuracy': 0.9674, 'precision': 0.7437, 'mcc': 0.6316}
Ensemble of 10 random models:


Unnamed: 0,mean,std,max
best_threshold,-0.91,0.366515,-0.5
valid_mcc,0.626113,0.003223,0.632314
sensitivity,0.56236,0.017296,0.596491
specificity,0.988997,0.001829,0.991637
accuracy,0.966901,0.000994,0.96828
precision,0.737745,0.025957,0.77931
mcc,0.627232,0.006275,0.634304


In [None]:
# Summarize the 'v004' prediction files with random sub-ensembles of 10 models.
summarize_prefix('v004', n=10)

Ensemble of all 13 models:
{'best_threshold': -1.6, 'valid_mcc': 0.6235, 'sensitivity': 0.5805, 'specificity': 0.9863, 'accuracy': 0.9653, 'precision': 0.6987, 'mcc': 0.619}
Ensemble of 10 random models:


Unnamed: 0,mean,std,max
best_threshold,-0.93,0.632543,0.1
valid_mcc,0.622823,0.002391,0.626438
sensitivity,0.558373,0.0273,0.594896
specificity,0.988361,0.002257,0.991637
accuracy,0.966091,0.00084,0.967124
precision,0.726133,0.029762,0.771971
mcc,0.619151,0.004823,0.627392


In [4]:
# Summarize the 'v005' prediction files with random sub-ensembles of 10 models.
summarize_prefix('v005', n=10)

Ensemble of all 20 models:
{'best_threshold': -1.0, 'valid_mcc': 0.6307, 'sensitivity': 0.555, 'specificity': 0.9895, 'accuracy': 0.967, 'precision': 0.742, 'mcc': 0.6253}
Ensemble of 10 random models:


Unnamed: 0,mean,std,max
best_threshold,-1.8,0.678233,-0.6
valid_mcc,0.628339,0.004138,0.6354
sensitivity,0.582616,0.016644,0.61563
specificity,0.986175,0.002256,0.989807
accuracy,0.965273,0.00131,0.967372
precision,0.698896,0.028599,0.748927
mcc,0.619921,0.005779,0.629431


In [None]:
# Summarize the 'v006' prediction files with random sub-ensembles of 10 models.
summarize_prefix('v006', n=10)

Ensemble of all 20 models:
{'best_threshold': -2.1, 'valid_mcc': 0.6283, 'sensitivity': 0.5933, 'specificity': 0.9862, 'accuracy': 0.9659, 'precision': 0.7019, 'mcc': 0.6277}
Ensemble of 10 random models:


Unnamed: 0,mean,std,max
best_threshold,-1.75,0.330824,-1.4
valid_mcc,0.62516,0.003694,0.631102
sensitivity,0.580064,0.010948,0.596491
specificity,0.987569,0.001735,0.989459
accuracy,0.966463,0.001265,0.967537
precision,0.719259,0.023954,0.745798
mcc,0.628647,0.00868,0.639425
