In [59]:
import os
from os import listdir
from os.path import isfile, join, dirname, basename
import pandas as pd

In [60]:
import matplotlib.pyplot as plt
import seaborn as sns

In [61]:
import matplotlib.gridspec as gridspec

In [62]:
# from analysis_utils import collect_cv_metrics, map_groups, plot_cv_metrics

In [63]:
%run analysis_utils.py

<Figure size 640x480 with 0 Axes>

In [64]:
def collect_auprc_scores(metric='AUPRC', clf='avg', base_dir='results'):
    all_results = []

    for subdir in os.listdir(base_dir):
        if subdir.startswith('__'):
            print(f'skipping {subdir}')
            continue
        if subdir.startswith('luad1'):
            print(f'skipping {subdir}')
            continue
        exp_path = os.path.join(base_dir, subdir)
        file_path = os.path.join(exp_path, f'{metric}_{clf}.csv')
        if os.path.isfile(file_path):
            try:
                df = pd.read_csv(file_path)
                df['experiment_name'] = subdir  # Add experiment name
                all_results.append(df[['experiment_name', 'experiment', 'group', 'mean']])
            except Exception as e:
                print(f"Failed to read {file_path}: {e}")

    if not all_results:
        raise ValueError(f"No valid {metric} files found.")

    combined_df = pd.concat(all_results)
    return combined_df


In [65]:
def process_and_sort(df):
    model_means = df.groupby('experiment')['mean'].mean().reset_index(name='mean_overall')
    # model_means = df.groupby('experiment')['mean'].median().reset_index(name='mean_overall')

    model_groups = df[['experiment', 'group']].drop_duplicates()

    sorted_models = model_means.merge(model_groups, on='experiment')
    sorted_models = sorted_models.sort_values(by=['group', 'mean_overall'], ascending=[True, False])
    sorted_model_order = sorted_models['experiment'].tolist()

    heatmap_df = df.pivot(index='experiment_name', columns='experiment', values='mean')
    heatmap_df = heatmap_df[sorted_model_order]

    return heatmap_df, sorted_models

In [66]:

def plot_composite_by_group_panels(heatmap_df, sorted_models, metric, clf='avg', save_dir='.'):
    import matplotlib.pyplot as plt
    import seaborn as sns
    import matplotlib.gridspec as gridspec
    import numpy as np
    from os.path import join
    from matplotlib.cm import get_cmap
    from matplotlib.colors import Normalize

    sorted_models = sorted_models.round(2)
    heatmap_df = heatmap_df.round(2)

    # Calculate row-wise averages and sort
    row_avg = heatmap_df.mean(axis=1)
    heatmap_df = heatmap_df.loc[row_avg.sort_values(ascending=False).index]

    grouped = sorted_models.groupby('group')
    n_groups = len(grouped)
    palette = sns.color_palette("husl", n_groups)
    color_map = dict(zip(grouped.groups.keys(), palette))

    global_vmin = heatmap_df.min().min()
    global_vmax = heatmap_df.max().max()
    norm = Normalize(vmin=global_vmin, vmax=global_vmax)
    cmap = get_cmap('viridis')

    total_models = sum(len(group_df) for _, group_df in grouped)
    fig = plt.figure(figsize=(max(6, total_models * 0.5 + 1.5), max(6, heatmap_df.shape[0] * 0.6)))

    group_model_counts = [len(group_df) for _, group_df in grouped]

    # gs = gridspec.GridSpec(
    #     3,
    #     n_groups + 2,
    #     width_ratios=group_model_counts + [0.2, 1.0],
    #     height_ratios=[1.2, 6, 0.4],
    #     hspace=0.3,
    #     wspace=0.1
    # )

    gs = gridspec.GridSpec(
        3,
        n_groups + 2,  # +2 for row avg and colorbar
        width_ratios=group_model_counts + [1.0, 0.2],  # row avg, colorbar
        height_ratios=[1.2, 6, 0.4],
        hspace=0.3,
        wspace=0.1
    )
        
    for i, (group_name, group_df) in enumerate(grouped):
        # cbar = (i == n_groups - 1)
        # cbar_ax = fig.add_subplot(gs[1, -2]) if cbar else None

        cbar = (i == n_groups - 1)
        cbar_ax = fig.add_subplot(gs[1, -1]) if cbar else None  # now last column
        
        color = color_map[group_name]
        model_names = group_df['experiment'].tolist()
        model_avg = group_df['mean_overall'].tolist()
        group_heatmap = heatmap_df[model_names]

        # Top bar plot
        ax_top = fig.add_subplot(gs[0, i])
        bars = ax_top.bar(range(len(model_avg)), model_avg, color=color)
        ax_top.set_xticks([])
        ax_top.set_yticks([])
        # ax_top.set_title(f"{group_name}\nAvg {metric}", fontsize=6)
        ax_top.set_ylim(0.7, 1)

        # Remove all spines (bounding boxes)
        for spine in ax_top.spines.values():
            spine.set_visible(False)

        for idx, bar in enumerate(bars):
            height = bar.get_height()
            ax_top.text(bar.get_x() + bar.get_width() / 2, height + 0.01, f"{height:.2f}", ha='center', va='bottom', fontsize=6)

        # Heatmap
        ax_heatmap = fig.add_subplot(gs[1, i])
        sns.heatmap(
            group_heatmap,
            ax=ax_heatmap,
            annot=True,
            fmt=".2f",
            cmap='viridis',
            vmin=global_vmin,
            vmax=global_vmax,
            cbar=cbar,
            cbar_ax=cbar_ax,
            xticklabels=False,
            annot_kws={"size": 6}
        )
        ax_heatmap.set_xticklabels([])
        ax_heatmap.tick_params(axis='x', bottom=False)
        ax_heatmap.set_xlabel(group_name)
        if i != 0:
            ax_heatmap.set_ylabel("")
            ax_heatmap.set_yticks([])
            ax_heatmap.set_yticklabels([])

        xticklabels = model_names
        ax_heatmap.tick_params(axis='x', rotation=90, labelsize=8)
        ax_heatmap.set_ylabel('')

        # Bottom color bar
        ax_bottom = fig.add_subplot(gs[2, i])
        for j in range(len(model_names)):
            ax_bottom.add_patch(plt.Rectangle((j, 0), 1, 1, color=color))
            
        ax_bottom.set_xlim(0, len(model_names))
        ax_bottom.set_ylim(0, 1)
        ax_bottom.axis('off')
        for j, model in enumerate(model_names):
            ax_bottom.text(j + 0.5, -0.2, model, ha='right', va='top', rotation=90, fontsize=10, rotation_mode='anchor')

     # Right-side bar plot: row average
    # ax_right_bar = fig.add_subplot(gs[1, -1])
    # avg_vals = heatmap_df.mean(axis=1)
    # y_positions = np.arange(len(avg_vals))
    
    # Right-side bar plot: row average (now second-to-last column)
    ax_right_bar = fig.add_subplot(gs[1, -2])
    avg_vals = heatmap_df.mean(axis=1)
    y_positions = np.arange(len(avg_vals))
    colors = [cmap(norm(val)) for val in avg_vals]

    ax_right_bar.barh(y_positions, avg_vals.values, color=colors)

    ax_right_bar.invert_yaxis()  # Align with heatmap rows
    ax_right_bar.set_xlim(0.4, 1.0)
    ax_right_bar.set_yticks([])
    ax_right_bar.set_xticks([])
    ax_right_bar.set_xlabel(f"Avg {metric}", fontsize=5)

    for spine in ax_right_bar.spines.values():
        spine.set_visible(False)
        

    # Use heatmap order and colormap for coloring
    colors = [cmap(norm(val)) for val in avg_vals]

    bars = ax_right_bar.barh(y_positions, avg_vals.values, color=colors)
    
    if cbar_ax is not None:
        cbar_ax.tick_params(labelsize=6)  # Control colorbar tick font size here

    # Align and remove ticks
    # ax_right_bar.invert_yaxis()  # Align top-down with heatmap
    # ax_right_bar.set_xlim(0.7, 1.0)
    # ax_right_bar.set_yticks([])
    # ax_right_bar.set_xticks([])
    # ax_right_bar.set_xlabel(f"Avg {metric}", fontsize=8)

    # Remove border/spines
    for spine in ax_right_bar.spines.values():
        spine.set_visible(False)

    # Group legend (optional)
    fig.subplots_adjust(bottom=0.25)
    legend_handles = [plt.Line2D([0], [0], color=c, lw=10) for c in color_map.values()]
    # fig.legend(legend_handles, color_map.keys(), loc='lower center', ncol=n_groups)

    output_path = join(save_dir, f'{metric}_{clf}.png')
    fig.savefig(output_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"Grouped panel heatmap saved to {output_path}")
    return fig

In [67]:
metrics= ['AUC', 'F1', 'AUPRC', 'Recall', 'Precision', 'Accuracy']

In [68]:
clfs = ['vote', 'mil', 'avg'] #['avg', 'mil']

In [69]:
# collect_auprc_scores

In [70]:
# group_map = {'geneformer':'Geneformer', 'baseline': 'Baseline', 'scgpt': 'scGPT'}
# 

In [71]:
# heatmap_df

In [72]:
# heatmap_df.to_csv('test.csv')

In [73]:
# metrics

In [74]:
# model_name_map

In [75]:
# test_df = collect_auprc_scores(metric='AUPRC',clf='avg', base_dir='./metrics')


In [76]:
# test_df

In [77]:
for clf in clfs:
    print(clf)
    for m in metrics:
        print(m)
        heatmap_df = collect_auprc_scores(metric=m,clf=clf, base_dir='./metrics')
        idx = heatmap_df.experiment.str.contains('continue')
        heatmap_df= heatmap_df[~idx]
        idx = heatmap_df.experiment.str.contains('freez')
        heatmap_df= heatmap_df[~idx]
        
        idx = heatmap_df.experiment.str.endswith('k')
        heatmap_df= heatmap_df[~idx]
        
        idx = heatmap_df.experiment.str.endswith('batch')
        heatmap_df= heatmap_df[~idx]
    
        idx = heatmap_df.experiment.str.endswith('_')
        heatmap_df= heatmap_df[~idx]
        
        idx = heatmap_df.experiment.str.endswith('all')
        heatmap_df= heatmap_df[~idx]
        
        # heatmap_df.group = heatmap_df.group.map(group_map)
        heatmap_df.experiment = heatmap_df.experiment.map(lambda x: model_name_map.get(x, x))
        heatmap_df.experiment_name = heatmap_df.experiment_name.map(experiment_name_map)
        # heatmap_df['experiment_name'] = heatmap_df.experiment
        heatmap_df_grouped, sorted_models = process_and_sort(heatmap_df)
        plot_composite_by_group_panels(heatmap_df_grouped, sorted_models, m, clf, save_dir='./plots')
        

vote
AUC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUC_vote.png
F1
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/F1_vote.png
AUPRC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUPRC_vote.png
Recall
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Recall_vote.png
Precision
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Precision_vote.png
Accuracy
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Accuracy_vote.png
mil
AUC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUC_mil.png
F1
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/F1_mil.png
AUPRC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUPRC_mil.png
Recall
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Recall_mil.png
Precision
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Precision_mil.png
Accuracy
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Accuracy_mil.png
avg
AUC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUC_avg.png
F1
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/F1_avg.png
AUPRC
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/AUPRC_avg.png
Recall
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Recall_avg.png
Precision
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Precision_avg.png
Accuracy
skipping luad1


  cmap = get_cmap('viridis')


Grouped panel heatmap saved to ./plots/Accuracy_avg.png
