# Setup

## Data Import

In [None]:
from itertools import combinations, permutations
from pathlib import Path

import numpy as np
import pandas as pd
from sqlite3 import connect

from matplotlib import pyplot as plt
import seaborn as sns

np.random.seed(707260)

In [None]:
# List the tables stored in the results database. sqlite_master also holds
# rows for indexes, views and triggers, so restrict the listing to real tables;
# otherwise those entries are later treated as (unreadable) study tables.
con = connect('../results.db')
try:
    tables = pd.read_sql(
        "SELECT name FROM sqlite_master WHERE type = 'table'",
        con=con
    ).loc[:, 'name']
finally:
    # Release the database handle even if the query raises
    con.close()

In [None]:
# Maps "<seg_algo>_<dataset>_<model>_<ori>_<weight>_<prep>" -> that study's DataFrame
df_map = {}

# Number of tables skipped (unreadable, incomplete, or with an unrecognised label)
bad_vals = 0

# Columns which together identify one analytical permutation
analysis_idx = ['seg_algo', 'dataset', 'model', 'weight', 'ori', 'prep']

con = connect('../results.db')
try:
    for t in tables:
        # Pull the dataframe from the database. The table name is quoted as an
        # SQL identifier so names with unusual characters cannot break the query.
        quoted_name = '"' + t.replace('"', '""') + '"'
        try:
            df = pd.read_sql(
                f"SELECT * FROM {quoted_name}",
                con=con
            )
        except Exception:
            # Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit through
            print(f"Failed to read table {t}, ignoring it")
            bad_vals += 1
            continue

        # If the table represents a study which wasn't run to completion, end early and report it
        if df.shape[0] < 1000:
            # print(f"Study {t} was not completed")
            bad_vals += 1
            continue

        # Split the DataFrame's label into its components
        label_comps = t.split('__')

        # The rest of the components are in the final tag
        final_comps = label_comps[-1].split('_')

        # Every branch below assigns all of seg_algo/dataset/ori/weight/prep, so
        # no value can leak in from the previously processed table.
        if len(label_comps) < 2:
            # No model component present in the label
            layout_ok = False
        elif final_comps[0] == 'clinical':
            # Clinical needs to be treated special:
            layout_ok = True
            seg_algo = 'none'
            dataset = 'clinical'
            ori = 'none'
            weight = 'none'
            prep = '_'.join(final_comps[2:])
        elif len(final_comps) > 3 and final_comps[1] == 'full':
            # The rest have the first index as the segmentation algorithm
            layout_ok = True
            seg_algo = final_comps[0]
            dataset = final_comps[1]
            ori = final_comps[2]
            weight = final_comps[3]
            prep = '_'.join(final_comps[4:])
        elif len(final_comps) > 4 and final_comps[1] == 'img':
            layout_ok = True
            seg_algo = final_comps[0]
            dataset = '_'.join(final_comps[1:3])
            ori = final_comps[3]
            weight = final_comps[4]
            prep = '_'.join(final_comps[5:])
        else:
            layout_ok = False

        if not layout_ok:
            print(f"Table {t} does not match a known label layout, ignoring it")
            bad_vals += 1
            continue

        # Pull the model label from it
        model = label_comps[1]

        df_key = "_".join([seg_algo, dataset, model, ori, weight, prep])

        # Store the components in the dataframe itself
        df['seg_algo'] = seg_algo
        df['dataset'] = dataset
        df['model'] = model
        df['weight'] = weight
        df['ori'] = ori
        df['prep'] = prep

        # Track the resulting dataframe via the result
        df_map[df_key] = df
finally:
    # Always release the database handle, even if a table fails unexpectedly
    con.close()

print(f"\nTotal No. bad values: {bad_vals}")

In [None]:
# Number of study tables that were read successfully and kept (>= 1000 rows)
len(df_map)

## Performance Metric Stacking

Every metric in the index list below is tracked for every analysis, so it is safe to query (and stack) across all analytical permutations.

In [None]:
# Column names selected from every study table when stacking performance
# metrics (see the markdown note above: these are tracked for all analyses).
shared_performance_metric_idxs = [
    "objective",
    "balanced_accuracy (validate)",
    "roc_auc (validate)",
    "log_loss (validate)",
    "balanced_accuracy (test)",
    "roc_auc (test)",
    "log_loss (test)",
    "importance_by_permutation (test)"
    
]

In [None]:
# Columns identifying a single row within one study: its replicate and trial
study_idxs = [
    "replicate",
    "trial"
]

In [None]:
def stack_performance_metrics():
    """Stack the shared columns of every loaded study into a single DataFrame.

    Reads the module-level `df_map`, keeps only the analysis identifiers, the
    study identifiers and the shared performance metrics of each study, and
    concatenates the results row-wise (original row labels are kept).
    """
    wanted_cols = [*analysis_idx, *study_idxs, *shared_performance_metric_idxs]
    return pd.concat([study_df.loc[:, wanted_cols] for study_df in df_map.values()])

performance_metric_df = stack_performance_metrics()

In [None]:
# Display the stacked performance metrics of all loaded studies
performance_metric_df

# Patient Metric Distributions

## Data Importing

In [None]:
# Load the tab-separated clinical table; later cells read its 'mJOA initial',
# 'mJOA 12 months', 'HRR' and demographic columns.
clinical_metric_df = pd.read_csv("../deepseg_data/clinical_only.tsv", sep='\t')
clinical_metric_df

## mJOA

Setup

In [None]:
def plot_distributions(data, cmap, legend_elements, xlabel, title, mean_offset=0, flip_mean_rot=False):
    """Plot a colour-banded histogram of `data` with a dashed mean line.

    Parameters
    ----------
    data : array-like of numbers
        Values to bin. Bin edges sit at integer + 0.1, so the bar centred on
        integer k counts the values in [k - 0.9, k + 0.1).
    cmap : dict
        Maps a numeric threshold to a bar colour. Bars whose bin starts below
        a threshold are drawn in its colour; thresholds are processed from
        largest to smallest so lower bands overdraw higher ones, regardless of
        the dict's insertion order.
    legend_elements : list
        Handles passed straight to `ax.legend`.
    xlabel, title : str
        Axis label and figure title.
    mean_offset : number, optional
        Distance below the top of the y-axis at which the mean label is placed.
    flip_mean_rot : bool, optional
        If True the mean label is drawn left of the line, rotated 90 degrees;
        otherwise right of the line, rotated -90 degrees.

    Returns
    -------
    (fig, ax) : the created matplotlib figure and axis.
    """
    # Get the appropriate ranges for the data
    min_range = int(np.min(data))-1
    max_range = int(np.max(data))+1

    # Bin the data
    hist, bins = np.histogram(
        data,
        np.array(range(min_range, max_range))+.1
    )

    # Generate the figure
    fig, ax = plt.subplots()

    # Iteratively color code the bars, widest band first
    for t, c in sorted(cmap.items(), key=lambda item: item[0], reverse=True):
        # Keep the bins whose left edge lies below the threshold
        vals = hist[bins[:-1] < t]
        # Derive the bar positions from the number of selected bins, so a
        # threshold above the data maximum cannot cause a length mismatch
        to_display = np.arange(min_range, min_range + vals.shape[0])+0.5
        ax.bar(
            to_display, vals,
            width=1, color=c,
            align='edge',
            edgecolor='black'
        )

    # Add a mean line
    data_mean = np.mean(data)
    ax.axvline(data_mean, ls='--', c='black')
    if flip_mean_rot:
        ax.text(data_mean-0.5, ax.get_ylim()[1]-mean_offset, f"Mean ({data_mean:.4})", rotation=90)
    else:
        ax.text(data_mean+0.05, ax.get_ylim()[1]-mean_offset, f"Mean ({data_mean:.4})", rotation=-90)

    # Add in the legend
    ax.legend(handles=legend_elements)

    # Add in labels
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)

    # Return the figure and axis
    return fig, ax

In [None]:
# Shared axis limits, derived from both mJOA time points together
mjoa_scores = [*clinical_metric_df['mJOA initial'], *clinical_metric_df['mJOA 12 months']]
xlim_min = int(np.min(mjoa_scores))-1
xlim_max = int(np.max(mjoa_scores))+1

# Tallest bar across both histograms (same binning as plot_distributions), plus headroom
shared_bins = np.array(range(xlim_min, xlim_max))+.1
initial_counts = np.histogram(clinical_metric_df['mJOA initial'], shared_bins)[0]
followup_counts = np.histogram(clinical_metric_df['mJOA 12 months'], shared_bins)[0]
ylim_min = 0
ylim_max = int(np.max([*initial_counts, *followup_counts]))+5

# Threshold -> bar colour, listed from the highest threshold down
severity_cmap = {
    18: 'blue',
    17: 'green',
    14: 'gold',
    11: 'red'
}

# Legend handles matching the colour map above
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=colour, edgecolor='black', label=label)
    for colour, label in (
        ('red', 'Severe'),
        ('gold', 'Moderate'),
        ('green', 'Mild'),
        ('blue', 'Healthy'),
    )
]

# Label DCM severity for both time points: start everyone at 'Severe', then
# promote each row past every cutoff its score exceeds.
# NOTE(review): a NaN score fails every `>` comparison and so keeps the
# 'Severe' default - confirm the mJOA columns have no missing values.
for score_col, label_col in (
    ('mJOA initial', 'DCM Severity initial'),
    ('mJOA 12 months', 'DCM Severity 12 months'),
):
    clinical_metric_df[label_col] = 'Severe'
    for cutoff, severity in ((11, 'Moderate'), (14, 'Mild'), (17, 'Healthy')):
        clinical_metric_df.loc[clinical_metric_df[score_col] > cutoff, label_col] = severity

# Directory the mJOA figures are written to; created on first use
mjoa_dist_out_path = Path('figures/mjoa_dist')
if not mjoa_dist_out_path.exists():
    mjoa_dist_out_path.mkdir(parents=True)

### Initial

In [None]:
# Histogram of the initial (pre-treatment) mJOA scores
fig, ax = plot_distributions(
    clinical_metric_df['mJOA initial'], severity_cmap, legend_elements,
    xlabel='mJOA', title='Pre-Treatment mJOA Scores (Full)', mean_offset=20
)

# Annotate each severity band with its patient count (positions are hand-placed)
severity_counts = clinical_metric_df['DCM Severity initial'].value_counts()
for x_pos, y_pos, severity in (
    (9, 15, 'Severe'),
    (14, 44.5, 'Moderate'),
    (16.5, 33, 'Mild'),
    (18, 2.5, 'Healthy'),
):
    ax.text(x_pos, y_pos, f"({severity_counts[severity]})", c='black', size=12, horizontalalignment='center')

# Write the figure to disk, then display it
fig.savefig(mjoa_dist_out_path / 'pre_treatment_mjoa.svg')
plt.show()

### 12 Month

In [None]:
# Histogram of the 12-month (post-treatment) mJOA scores
fig, ax = plot_distributions(
    clinical_metric_df['mJOA 12 months'], severity_cmap, legend_elements,
    xlabel='mJOA', title='Post-Treatment mJOA Scores (Full)', mean_offset=20, flip_mean_rot=True
)

# Annotate each severity band with its patient count (positions are hand-placed)
severity_counts = clinical_metric_df['DCM Severity 12 months'].value_counts()
for x_pos, y_pos, severity in (
    (8.5, 3, 'Severe'),
    (13.5, 36, 'Moderate'),
    (16, 45, 'Mild'),
    (18, 37, 'Healthy'),
):
    ax.text(x_pos, y_pos, f"({severity_counts[severity]})", c='black', size=12, horizontalalignment='center')

# Write the figure to disk, then display it
fig.savefig(mjoa_dist_out_path / 'post_treatment_mjoa.svg')
plt.show()

### mJOA Delta

In [None]:
# Define a new color scheme and legend for this new style of data
# (threshold -> colour, listed from the highest threshold down, as plot_distributions expects)
delta_cmap = {
    8: 'springgreen',
    0: 'white',
    -1: 'salmon'
}

delta_legend_elements = [
    Patch(facecolor='springgreen', edgecolor='black', label='Improved'),
    Patch(facecolor='white', edgecolor='black', label='No Change'),
    Patch(facecolor='salmon', edgecolor='black', label='Declined'),
]

# NOTE(review): `xticks` is not referenced by any later cell visible in this notebook
xticks = (
    list(range(-8, 9, 2)),
    list(range(-8, 9, 2))
)

# Per-patient change in mJOA: positive means the 12-month score is higher than the initial one
deltas = clinical_metric_df['mJOA 12 months'] - clinical_metric_df['mJOA initial']

In [None]:
# Histogram of the per-patient mJOA change
fig, ax = plot_distributions(
    deltas, delta_cmap, delta_legend_elements,
    xlabel="mJOA Change", title='Change in mJOA Over 1 Year (Full)', mean_offset=20, flip_mean_rot=True
)

# Count patients per outcome: (-20, -1] declined, (-1, 0] unchanged, (0, 20] improved
change_bin_edges = [-20, -1, 0, 20]
change_labels = ['Declined', 'No Change', 'Improved']
change_counts = pd.cut(deltas, change_bin_edges, labels=change_labels).value_counts()

# Annotate each outcome band with its patient count (positions are hand-placed)
for x_pos, y_pos, change in (
    (-4.5, 9, 'Declined'),
    (-0.6, 40, 'No Change'),
    (4, 32, 'Improved'),
):
    ax.text(x_pos, y_pos, f"({change_counts[change]})", c='black', size=12, verticalalignment='center')

# Write the figure to disk, then display it
fig.savefig(mjoa_dist_out_path / 'treatment_mjoa_delta.svg')
plt.show()

## Hirabayashi Recovery Ratio Distribution

Setup

In [None]:
from scipy.stats import gaussian_kde

# Plot the KDE distribution onto an existing plot
def plot_kde(ax, values, c='black', ls='-', label=None):
    """Draw a unit-norm Gaussian kernel density estimate of `values` on `ax`.

    Parameters
    ----------
    ax : object with a matplotlib-style `plot` method
        Axis to draw on.
    values : 1-D array-like of finite numbers
        Samples to estimate the density from.
    c, ls : str, optional
        Line colour and line style forwarded to `ax.plot`.
    label : str or None, optional
        Legend label; when None no `label` keyword is passed to `ax.plot`.
    """
    # A scalar bw_method is used directly as the KDE bandwidth factor, which
    # avoids patching `covariance_factor` and calling a private method.
    kde = gaussian_kde(values, bw_method=0.15)

    # Evaluate the density on 200 evenly spaced points spanning the data
    xs = np.linspace(np.min(values), np.max(values), 200)
    ys = kde(xs)

    # Scale the curve to unit Euclidean norm so curves are comparable in height
    ys /= np.linalg.norm(ys)

    plot_kwargs = {} if label is None else {'label': label}
    ax.plot(xs, ys, ls=ls, c=c, **plot_kwargs)

# Clean out invalid values from the set
def clean_vals(df):
    """Return `df` with its negative-infinity and missing entries removed."""
    not_negative_inf = df != -np.inf
    return df[not_negative_inf].dropna()

# Adds important reference lines to the plot
def draw_line_references(ax):
    """Draw the HRR = 0.5 marker and the two zero baselines on `ax`."""
    # Significant improvement threshold (dash-dot grey vertical line at 0.5)
    ax.axvline(0.5, ls='-.', c='grey')

    # Dotted light-grey baselines through zero: horizontal first, then vertical
    for draw_baseline in (ax.axhline, ax.axvline):
        draw_baseline(0, ls=":", c='lightgrey')

# The HRR Equation, for immediate reference within the plot.
# Raw string containing matplotlib mathtext: (12-month mJOA - initial mJOA) / (18 - initial mJOA)
hirabayashi_equation = r"HRR = $\frac{\mathrm{mJOA (1 Year)} - \mathrm{mJOA (Initial)}}{18 - \mathrm{mJOA (Initial)}}$"

In [None]:
# Get the HRR for our patients, skipping over initially healthy patients who could not improve whatsoever
initial_severity = clinical_metric_df['DCM Severity initial']
hrr_df = clinical_metric_df.loc[initial_severity != "Healthy", 'HRR']

# Clean once, so the overall curve and the ratios are computed from exactly the
# same valid values as the per-severity curves (previously only those were cleaned)
hrr_valid = clean_vals(hrr_df)
severity_of_valid = initial_severity.loc[hrr_valid.index]

# Generate the initial plot
fig, ax = plt.subplots()

# Plot our reference lines
draw_line_references(ax)

# Plot the distributions by their initial severity class
for severity, colour in (('Severe', 'red'), ('Moderate', 'gold'), ('Mild', 'green')):
    plot_kde(
        ax, hrr_valid[severity_of_valid == severity], ls='--', c=colour, label=severity
    )

# Plot the overall distribution
plot_kde(ax, hrr_valid, c='blue', label='All')

# Calculate the ratio above and below the HRR significance threshold, and add it
good_ratio = np.sum(hrr_valid >= 0.5)/hrr_valid.shape[0]
fair_ratio = np.sum(hrr_valid < 0.5)/hrr_valid.shape[0]

ax.text(0.7, 0.238, f"{good_ratio: .2f}", c='purple')
ax.text(-0.5, 0.238, f"{fair_ratio: .2f}", c='purple')

# Add axis labels
ax.set_xlabel('Hirabayashi Recovery Ratio (HRR)')
ax.set_ylabel('Normalized Kernel Density Estimate')

# Add a legend
ax.legend(title='Pre-Surgical DCM Severity')

# Add hirabayashi equation directly to plot
ax.text(-8, 0.15, hirabayashi_equation)

# Add a title
ax.set_title("Distribution of Hirabayashi Recovery Ratio")

plt.tight_layout()

fig.savefig(mjoa_dist_out_path / 'hirabayashi_ratios.svg')

plt.show()

## Demographics

In [None]:
def plot_continuous_demographics(col):
    """Plot and save the distribution of one continuous clinical column.

    Draws a seaborn `displot` of `clinical_metric_df[col]`, titles it, and
    writes it to ``figures/demo_dist/<col lower-cased, spaces as '_'>_dist.svg``.

    Parameters
    ----------
    col : str
        Name of the column in `clinical_metric_df` to plot.
    """
    # savefig does not create missing directories, so make sure the target exists
    out_dir = Path("figures/demo_dist")
    out_dir.mkdir(parents=True, exist_ok=True)

    sns.displot(clinical_metric_df, x=col)
    plt.title(f"Patient Distribution ({col})")
    plt.xlabel(col)
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(out_dir / f"{'_'.join(col.lower().split(' '))}_dist.svg")
    plt.show()

In [None]:
# Clinical columns plotted as continuous distributions
continuous_demographic_cols = [
    "Age",
    "BMI"
]

In [None]:
# One distribution figure per continuous demographic column
for c in continuous_demographic_cols:
    plot_continuous_demographics(c)

In [None]:
def plot_categorical_demographics(col):
    """Plot and save a pie chart of one categorical clinical column.

    Slices are the value counts of `clinical_metric_df[col]`, annotated with
    their percentage; category names appear in the legend. The figure is
    written to ``figures/demo_dist/<col lower-cased, spaces as '_'>_dist.svg``.

    Parameters
    ----------
    col : str
        Name of the column in `clinical_metric_df` to plot.
    """
    # savefig does not create missing directories, so make sure the target exists
    out_dir = Path("figures/demo_dist")
    out_dir.mkdir(parents=True, exist_ok=True)

    col_counts = clinical_metric_df[col].value_counts()
    plt.pie(col_counts, labels=None, autopct=lambda x: f'{x: .2f}%')
    plt.legend(labels=col_counts.index)
    plt.title(f"Patient Distribution ({col})")
    plt.tight_layout()
    plt.savefig(out_dir / f"{'_'.join(col.lower().split(' '))}_dist.svg")
    plt.show()

In [None]:
# Clinical columns plotted as pie charts of their value counts
categorical_demographic_cols = [
    "Sex",
    "Work Status (Category)",
    "Symptom Duration"
]

In [None]:
# One pie chart per categorical demographic column
for c in categorical_demographic_cols:
    plot_categorical_demographics(c)

# Best across Trial

## Utility Functions

In [None]:
# Analysis identifier columns followed by the two summary statistic names
best_across_trials_idx = [*analysis_idx, 'Mean', 'STD']
best_across_trials_idx

In [None]:
# Gets the values of one column when the value of another is among the n-highest (default to n=1)
def get_peak_at_max_other(target_col, other_col, df=performance_metric_df, n=1) -> pd.DataFrame:
    """Summarise `target_col` at the rows where `other_col` ranks highest.

    For every (analysis permutation, replicate) group the `n` rows that sort
    last by `other_col` are kept; `target_col` of those rows is then averaged
    per analysis permutation.

    Parameters
    ----------
    target_col : str
        Column whose values are summarised.
    other_col : str
        Column used to rank the rows within each group.
    df : pd.DataFrame, optional
        Source frame; defaults to the stacked performance metrics as they
        were when this function was defined.
    n : int, optional
        Number of top-ranked rows kept per group.

    Returns
    -------
    pd.DataFrame
        One row per analysis permutation with 'Mean' and 'STD' columns.
    """
    # Top-n rows of every analysis/replicate group, ranked by the other column
    replicate_keys = [*analysis_idx, 'replicate']
    top_rows = df.sort_values(by=other_col).groupby(replicate_keys).tail(n)

    # Collapse the replicates of each analysis permutation into mean and spread
    target_by_analysis = top_rows.reset_index().groupby(analysis_idx)[target_col]
    means = target_by_analysis.mean()

    summary_df = pd.DataFrame(index=list(means.index))
    summary_df['Mean'] = means
    summary_df['STD'] = target_by_analysis.std()
    return summary_df

## Balanced Accuracy (Test at Peak Validation)

### Test @ Peak Validation **[MAIN RESULT]**

In [None]:
# Ten analysis permutations with the highest mean test balanced accuracy,
# taken at each replicate's peak validation balanced accuracy
get_peak_at_max_other('balanced_accuracy (test)', 'balanced_accuracy (validate)').sort_values(by='Mean').tail(10)

### Test @ Peak Test [Theoretical Potential]

In [None]:
# Ten analysis permutations with the highest mean test balanced accuracy,
# taken at each replicate's peak test balanced accuracy
get_peak_at_max_other('balanced_accuracy (test)', 'balanced_accuracy (test)').sort_values(by='Mean').tail(10)

# Performance Across Trials

## Utility Functions

In [None]:
def plot_average_performance_across_trials(df, metric, grouping, fpath):
    """Plot `metric` against trial number, one line per value of `grouping`.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with a 'trial' column plus the `metric` and `grouping` columns.
    metric : str
        Column plotted on the y-axis.
    grouping : str
        Column used as the line hue; also capitalised into the title.
    fpath : path-like
        Destination the figure is saved to before being shown.
    """
    # seaborn aggregates repeated x values into a mean line with an error band
    ax = sns.lineplot(data=df, x='trial', y=metric, hue=grouping)

    # Title and layout
    ax.set_title(f'By {grouping.capitalize()} (Average)')
    plt.tight_layout()

    # Persist the figure, then display it
    plt.savefig(fpath)
    plt.show()

## Balanced Accuracy (Test)

In [None]:
output_dir = Path("figures/bacc_performance/")
# savefig does not create missing directories, so make sure the target exists
output_dir.mkdir(parents=True, exist_ok=True)

# One average-performance figure per analysis identifier
for i in analysis_idx:
    plot_average_performance_across_trials(performance_metric_df, 'balanced_accuracy (test)', i, output_dir/f'bacc_avg_by_{i}.png')

## Balanced Accuracy (Test) at Peak Balanced Accuracy (Validate)

In [None]:
def plot_metric_at_peak_other_across_trials(df, metric, other, grouping, fpath):
    """Plot `metric` per trial, using only each group's top row by `other`.

    Within every (replicate, trial, grouping value) group the row that sorts
    last by `other` is kept; `metric` of those rows is then drawn against the
    trial number with one line per value of `grouping`.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with 'replicate' and 'trial' columns plus the named columns.
    metric : str
        Column plotted on the y-axis.
    other : str
        Column used to rank the rows within each group.
    grouping : str
        Column used as the line hue; also capitalised into the title.
    fpath : path-like
        Destination the figure is saved to before being shown.
    """
    # Keep the top-ranked row of every replicate/trial/grouping combination
    ranked_df = df.sort_values(other)
    peak_rows = ranked_df.groupby(['replicate', 'trial', grouping]).tail(1).reset_index()

    # seaborn aggregates repeated x values into a mean line with an error band
    ax = sns.lineplot(data=peak_rows, x='trial', y=metric, hue=grouping)

    # Title and layout
    ax.set_title(f'By {grouping.capitalize()} (B.Acc Test @ Peak Validation)')
    plt.tight_layout()

    # Persist the figure, then display it
    plt.savefig(fpath)
    plt.show()

In [None]:
# One figure per analysis identifier; `output_dir` is defined in an earlier cell
for i in analysis_idx:
    plot_metric_at_peak_other_across_trials(performance_metric_df, 'balanced_accuracy (test)', 'balanced_accuracy (validate)', i, output_dir/f'bacc_test_at_peak_validate_by_{i}.png')

## Balanced Accuracy (Test) Weighted by Balanced Accuracy (Validate)


In [None]:
def weighted_std(vals, weights):
    """Return the weighted (population) standard deviation of `vals`.

    Parameters
    ----------
    vals : array-like of numbers
        Sample values.
    weights : array-like of numbers
        One non-negative weight per value, with a non-zero sum.

    Returns
    -------
    float
        Square root of the weighted mean squared deviation from the weighted mean.
    """
    vals = np.asarray(vals, dtype=float)
    mean_val = np.average(vals, weights=weights)
    variance = np.average((vals-mean_val)**2, weights=weights)
    # The weighted average of squared deviations is the variance; the
    # standard deviation is its square root
    return np.sqrt(variance)

In [None]:
def metric_weighted_by_other(df, metric, weight, grouping, fpath):
    """Plot the `weight`-weighted average of `metric` per trial and group.

    For every (grouping value, trial) pair the weighted mean of `metric` and
    its weighted spread (via `weighted_std`) are computed with `weight` as the
    weights. Each grouping value is drawn as a line with a shaded band of
    mean +/- spread. The figure is saved to `fpath` and then shown.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with the `study_idxs` columns plus the named columns.
    metric : str
        Column being averaged.
    weight : str
        Column supplying the weights.
    grouping : str
        Column whose values each get their own line.
    fpath : path-like
        Destination the figure is saved to.
    """
    # Calculate the weighted metrics from the original dataset
    df_groupedby = df.loc[:, [grouping, *study_idxs, metric, weight]].groupby([grouping, 'trial'])
    sub_df = pd.DataFrame({
        'Mean': df_groupedby.apply(lambda x: np.average(x[metric], weights=x[weight]), include_groups=False),
        'STD': df_groupedby.apply(lambda x: weighted_std(x[metric], x[weight]), include_groups=False),
    })

    # Plot each of them iteratively, w/ weighted mean and std
    fig, ax = plt.subplots(1)
    # Take the groups from the (sorted) groupby index so the line order and
    # colours are the same on every run
    group_options = sub_df.index.get_level_values(grouping).unique()
    for i, g in enumerate(group_options):
        # Statistics of this group, indexed by trial
        group_stats = sub_df.xs(g, level=grouping)
        y_mean = group_stats['Mean']
        y_std = group_stats['STD']

        # Plot the main line
        ax.plot(y_mean.index, y_mean, color=f'C{i}', label=g)

        # Plot the (weighted) standard deviation fills against the same trial
        # labels as the line, rather than against positions 0..n-1
        ax.fill_between(y_mean.index, y_mean+y_std, y_mean-y_std, facecolor=f'C{i}', alpha=0.2)

    # Add other plotted elements
    ax.set_xlabel('Trial')
    ax.set_ylabel('Weighted Average')
    ax.legend(title=grouping)

    # Save the figure to the requested path (previously `fpath` was ignored);
    # savefig does not create missing directories, so create the parent first
    Path(fpath).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(fpath)
    plt.show()

In [None]:
# One weighted-average figure per analysis identifier; `output_dir` is defined in an earlier cell
for i in analysis_idx:
    metric_weighted_by_other(performance_metric_df, 'balanced_accuracy (test)', 'balanced_accuracy (validate)', i, output_dir/f'bacc_weighted_avg_by_{i}.png')

# Statistical Tests

## Setup

In [None]:
from itertools import combinations, permutations

from scipy.stats import normaltest, ranksums, kruskal

Target metric gathering function

In [None]:
# Absolute peak values by replicate, mean and std
def get_best_per_replicate(target_value):
    """Collect, for every loaded study, each replicate's row with the peak `target_value`.

    Parameters
    ----------
    target_value : str
        Column to maximise within each replicate.

    Returns
    -------
    pd.DataFrame
        One row per (study, replicate) with the 'replicate' column, the
        analysis identifiers, 'trial' and `target_value`. Replicates whose
        `target_value` is entirely missing are omitted.
    """
    keep_cols = [*analysis_idx, 'trial', target_value]
    component_dfs = []
    for df in df_map.values():
        # Take the whole top-ranked row of each replicate. GroupBy.last()
        # returns the last non-null entry of every column independently, and
        # sort_values places NaN last, so with missing scores it could pair
        # one trial's identifiers with another trial's value.
        peak_df = (
            df.dropna(subset=[target_value])
            .sort_values(by=target_value)
            .groupby('replicate')
            .tail(1)
            .set_index('replicate')
            .sort_index()
            .loc[:, keep_cols]
        )
        component_dfs.append(peak_df)
    result_df = pd.concat(component_dfs).reset_index()
    return result_df

In [None]:
# Values of one metric, sampled at the peak value of another, per-replicate mean and STD sampled
def get_val_at_best_other_per_replicate(target, other):
    """Collect, for every loaded study, `target` at each replicate's peak `other`.

    Parameters
    ----------
    target : str
        Column to report.
    other : str
        Column to maximise within each replicate.

    Returns
    -------
    pd.DataFrame
        One row per (study, replicate) with the 'replicate' column, the
        analysis identifiers, 'trial' and `target`. Replicates whose `other`
        is entirely missing are omitted.
    """
    keep_cols = [*analysis_idx, 'trial', target]
    component_dfs = []
    for df in df_map.values():
        # Take the whole top-ranked row of each replicate. GroupBy.last()
        # returns the last non-null entry of every column independently, and
        # sort_values places NaN last, so with missing scores it could pair
        # one trial's identifiers with another trial's value.
        peak_df = (
            df.dropna(subset=[other])
            .sort_values(by=other)
            .groupby('replicate')
            .tail(1)
            .set_index('replicate')
            .sort_index()
            .loc[:, keep_cols]
        )
        component_dfs.append(peak_df)
    result_df = pd.concat(component_dfs).reset_index()
    return result_df

In [None]:
def evaluate_normality(df, query_key, target):
    """Run `scipy.stats.normaltest` on `target` for every level of `query_key`.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding both columns.
    query_key : str
        Column whose distinct (non-missing) values define the groups.
    target : str
        Column whose values are tested within each group.

    Returns
    -------
    pd.DataFrame
        Indexed by group level, with a single 'p-value' column.
    """
    isnormal = {}

    # Take the levels from the frame that was passed in (not from a notebook
    # global), in order of first appearance so the output order is stable
    for k in df[query_key].dropna().unique():
        x = df.loc[df[query_key] == k, target]
        isnormal[k] = [normaltest(x).pvalue]

    # Save the results as a dataframe
    return pd.DataFrame.from_dict(isnormal, orient='index', columns=['p-value'])

In [None]:
# Symbol used in the comparison label for each `alternative` accepted by ranksums
alt_keys = {
    'two-sided': '!=',
    'greater':   '>',
    'less':      '<'
}

def paired_rankedsum(df, query, target, alternative='two-sided'):
    """Wilcoxon rank-sum test of `target` between every ordered pair of `query` levels.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding both columns.
    query : str
        Column whose distinct (non-missing) values define the groups.
    target : str
        Column whose values are compared between groups.
    alternative : {'two-sided', 'greater', 'less'}, optional
        Alternative hypothesis forwarded to `scipy.stats.ranksums`.

    Returns
    -------
    pd.DataFrame
        Indexed by 'Comparison' labels of the form "<v1> <symbol> <v2>", with
        a single 'p' column holding the p-value of that comparison.
    """
    # Extract every group's samples once instead of re-filtering per pair;
    # a boolean mask also works for levels that are not strings
    samples = {
        level: df.loc[df[query] == level, target]
        for level in df[query].dropna().unique()
    }

    # Calculate the native rank-sum p-value for each ordered pair of groups
    pvals = {}
    for v1, v2 in permutations(samples, 2):
        p = ranksums(samples[v1], samples[v2], alternative=alternative).pvalue
        pvals[f"{v1} {alt_keys[alternative]} {v2}"] = [p]

    # Save the results as a dataframe
    return_df = pd.DataFrame.from_dict(pvals, orient='index', columns=['p'])
    return_df.index.name = 'Comparison'
    return return_df

In [None]:
def evaluate_kw(df, grouping, target):
    """Return the Kruskal-Wallis p-value of `target` across the levels of `grouping`.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding both columns.
    grouping : str
        Column whose distinct values define the samples being compared.
    target : str
        Column whose values make up each sample.
    """
    samples = []
    for level in set(df[grouping]):
        samples.append(df.query(f"{grouping} == '{level}'")[target])
    return kruskal(*samples).pvalue

## Testing Balanced Accuracy

### Testing @ Peak Validation

#### Raw Performance

In [None]:
# Metric to report, and the metric whose per-replicate peak selects the row it is read from
target = 'balanced_accuracy (test)'
other = 'balanced_accuracy (validate)'
replicate_test_at_peak_bacc_df = get_val_at_best_other_per_replicate(target, other)
replicate_test_at_peak_bacc_df

#### Ranked-Sum Grouping Comparisons

In [None]:
# Calculate the p-values for whether one experimental permutation has greater average balanced accuracy performance than another
# (one-sided rank-sum test between every ordered pair of levels, for each analysis identifier)
sub_dfs = []
for k in analysis_idx:
    tmp_df = paired_rankedsum(replicate_test_at_peak_bacc_df, k, target, alternative='greater')
    sub_dfs.append(tmp_df)

# Stack all comparisons and order them from most to least significant
sig_test_at_peak_valid_df = pd.concat(sub_dfs).sort_values('p')

# Calculate the corrected p-value significance as well
# (each p-value is multiplied by the total number of comparisons before thresholding)
n_samples = sig_test_at_peak_valid_df.shape[0]
sig_test_at_peak_valid_df['significance'] = ''
# Thresholds tighten on each pass, so a row ends up with one star per threshold it clears
for i, t in enumerate([0.05, 0.01, 0.001]):
    sig_test_at_peak_valid_df.loc[sig_test_at_peak_valid_df['p']*n_samples < t, 'significance'] = '*'*(i+1)

# Show the 50 comparisons with the smallest p-values
sig_test_at_peak_valid_df.reset_index().head(50)

#### Kruskal-Wallis

In [None]:
# Using Kruskal-Wallis, confirm that there is a significant difference in the best-case performance for each analytical variation
kw_pvals = {}
for i in analysis_idx:
    kw_pvals[i] = [evaluate_kw(replicate_test_at_peak_bacc_df, i, 'balanced_accuracy (test)')]
# One row per analysis identifier, holding its p-value
kw_df = pd.DataFrame.from_dict(kw_pvals).T
kw_df.columns = ['p']

# Calculate the corrected p-value significance as well w/ Bonferroni correction
# (each p-value is multiplied by the number of tests before thresholding)
kw_df['significance'] = ''
n_samples = kw_df.shape[0]
# Thresholds tighten on each pass, so a row ends up with one star per threshold it clears
for i, t in enumerate([0.05, 0.01, 0.001]):
    kw_df.loc[kw_df['p']*n_samples < t, 'significance'] = '*'*(i+1)

kw_df

# Feature Importance

## Utility Functions

In [None]:
def format_feature_imp(val):
    """Parse a "{name: value, name: value}"-style string into a dict of floats.

    One enclosing bracket on each side is removed, together with any
    surrounding whitespace or trailing comma. The previous fixed `val[1:-2]`
    slice always dropped two trailing characters, which cuts the last digit
    off the final value when the text ends in a single closing bracket.
    NOTE(review): the stored text format is not visible in this notebook -
    confirm against a raw 'importance_by_permutation (test)' entry.

    Parameters
    ----------
    val : str
        Comma-separated "name: value" items inside one pair of brackets.
        Names may themselves contain ': '; the text after the last ': ' of an
        item is the value.

    Returns
    -------
    dict
        Maps each name to its value as a float; empty for an empty listing.
    """
    # Strip leading and trailing brackets (one per side) and stray separators
    body = val.strip().rstrip(',').strip()
    if body[:1] in ('[', '{', '('):
        body = body[1:]
    if body[-1:] in (']', '}', ')'):
        body = body[:-1]
    body = body.strip().rstrip(',').strip()

    # Create a dictionary from the remaining components
    imp_dict = dict()
    for item in body.split(', '):
        if not item.strip():
            # Nothing between the brackets (or a doubled separator)
            continue
        # Split on the last ': ' so names containing ': ' stay intact
        name, _, raw_value = item.rpartition(': ')
        imp_dict[name] = float(raw_value)

    return imp_dict

In [None]:
def feature_importance_report(df: pd.DataFrame, weight_col, feature_col):
    """Summarise per-feature importances, raw and scaled by a per-row weight.

    Note the argument order: the weight column comes before the feature column.

    Parameters
    ----------
    df : pd.DataFrame
        One row per model; `feature_col` holds a dict mapping feature name to
        importance and `weight_col` holds a numeric weight.
    weight_col : str
        Column whose value multiplies every importance of its row.
    feature_col : str
        Column holding the importance dicts.

    Returns
    -------
    pd.DataFrame
        Indexed by feature name, with the mean and (sample) standard deviation
        of the raw and of the weight-scaled importances. A feature missing
        from a row counts as 0 for that row.
    """
    # Build the row x feature table in one step rather than concatenating one
    # single-row DataFrame per model; features absent from a row become NaN
    per_row_imps = pd.DataFrame(list(df[feature_col]))

    # Scale each row by its weight before filling the gaps with zeros
    row_weights = df[weight_col].to_numpy()
    weighted_feature_imps = per_row_imps.mul(row_weights, axis=0).fillna(0)
    raw_feature_imps = per_row_imps.fillna(0)

    # Interpret the results into a clean report
    feature_imp_report = {
        "Mean (Raw)": raw_feature_imps.mean(),
        "STD (Raw)": raw_feature_imps.std(),
        "Mean (Performance Weighted)": weighted_feature_imps.mean(),
        "STD (Performance Weighted)": weighted_feature_imps.std(),
    }
    result_df = pd.DataFrame.from_dict(feature_imp_report)

    return result_df

In [None]:
def feature_imp_report(df: pd.DataFrame, feature_col, weight_col) -> pd.DataFrame:
    """Summarise per-feature importances with plain and weighted statistics.

    Note the argument order: the feature column comes before the weight column.

    Parameters
    ----------
    df : pd.DataFrame
        One row per model; `feature_col` holds a dict mapping feature name to
        importance and `weight_col` holds a numeric weight.
    feature_col : str
        Column holding the importance dicts.
    weight_col : str
        Column supplying the per-row weights.

    Returns
    -------
    pd.DataFrame
        Indexed by feature name with columns 'Mean', 'STD' (population
        standard deviation), 'Weighted Mean' and 'Weighted STD' (the value
        returned by `weighted_std`). A feature missing from a row counts as 0
        for that row.
    """
    # Build the row x feature table in one step rather than concatenating one
    # single-row DataFrame per model
    raw_feature_imps = pd.DataFrame(list(df[feature_col])).fillna(0)

    # Query the weights list a single time to avoid repeated querying expense
    weights = df[weight_col].to_numpy()

    # For each feature, calculate our desired statistics
    return_cols = ['Mean', 'STD', 'Weighted Mean', 'Weighted STD']
    return_df_dict = {}
    for c in raw_feature_imps.columns:
        # Single extraction of the column as a plain array
        samples = raw_feature_imps[c].to_numpy()
        # Raw mean, raw STD, weighted mean and weighted STD, in `return_cols` order
        return_df_dict[c] = [
            np.mean(samples),
            np.std(samples),
            np.average(samples, weights=weights),
            weighted_std(samples, weights),
        ]

    # Return the result as a dataframe
    return pd.DataFrame.from_dict(return_df_dict, columns=return_cols, orient='index')

## Setup

In [None]:
# Isolate and stack the information relative to the value
# (identifiers, test balanced accuracy and the raw importance text of every study)
sub_dfs = []

for df in df_map.values():
    tmp_df = df.loc[:, [*study_idxs, *analysis_idx, 'balanced_accuracy (test)', 'importance_by_permutation (test)']]
    sub_dfs.append(tmp_df)

feature_imp_df = pd.concat(sub_dfs)

# Isolate only the best trial from each replicate
# (the row sorting last by test balanced accuracy within each analysis/replicate group)
feature_imp_df = feature_imp_df.sort_values('balanced_accuracy (test)').groupby([*analysis_idx, 'replicate']).tail(1).set_index(analysis_idx)

# Parse the feature importance list into a cleaner dictionary
feature_imp_df['importance_by_permutation (test)'] = feature_imp_df['importance_by_permutation (test)'].apply(format_feature_imp)
feature_imp_df

In [None]:
# Isolate PCA-derived features from the rest
# (rows whose 'prep' label contains the substring 'pca'), re-indexed by the analysis identifiers
pca_feature_imp_df = feature_imp_df.reset_index().loc[feature_imp_df.reset_index()['prep'].apply(lambda x: 'pca' in x), :].set_index([*analysis_idx])
pca_feature_imp_df

In [None]:
# Everything else: drop every row whose index label appears in the PCA subset
nonpca_feature_imp_df = feature_imp_df.drop(pca_feature_imp_df.index)
nonpca_feature_imp_df

## Un-transformed Features

In [None]:
# Unsorted feature-importance report for the non-PCA runs on the 'full' dataset,
# weighted by test balanced accuracy
full_feature_imp_df = nonpca_feature_imp_df.query("dataset == 'full'")
feature_imp_report(full_feature_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')

### Full dataset (all possible features)

In [None]:
# Ten features with the highest mean importance among non-PCA runs on the 'full' dataset
full_feature_imp_df = nonpca_feature_imp_df.query("dataset == 'full'")
full_feature_report = feature_imp_report(full_feature_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
full_feature_report.sort_values("Mean", ascending=False).head(10)

### Image-derived features only

In [None]:
# Ten features with the highest mean importance among non-PCA runs on the 'img_only' dataset
img_feature_imp_df = nonpca_feature_imp_df.query("dataset == 'img_only'")
img_feature_report = feature_imp_report(img_feature_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
img_feature_report.sort_values("Mean", ascending=False).head(10)

### Clinical Features

In [None]:
# Ten features with the highest mean importance among non-PCA runs on the 'clinical' dataset
clin_feature_imp_df = nonpca_feature_imp_df.query("dataset == 'clinical'")
clin_feature_report = feature_imp_report(clin_feature_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
clin_feature_report.sort_values("Mean", ascending=False).head(10)

### PCA

In [None]:
# Ten features with the highest mean importance among PCA runs on the 'full' dataset
full_pca_imp_df = pca_feature_imp_df.query("dataset == 'full'")
full_pca_report = feature_imp_report(full_pca_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
full_pca_report.sort_values("Mean", ascending=False).head(10)

In [None]:
# Ten features with the highest mean importance among PCA runs on the 'img_only' dataset
img_pca_imp_df = pca_feature_imp_df.query("dataset == 'img_only'")
img_pca_report = feature_imp_report(img_pca_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
img_pca_report.sort_values("Mean", ascending=False).head(10)

In [None]:
# Ten features with the highest mean importance among PCA runs on the 'clinical' dataset
clin_pca_imp_df = pca_feature_imp_df.query("dataset == 'clinical'")
clin_pca_report = feature_imp_report(clin_pca_imp_df, 'importance_by_permutation (test)', 'balanced_accuracy (test)')
clin_pca_report.sort_values("Mean", ascending=False).head(10)