In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
# Apply the shared Matplotlib style sheet whose path is supplied as a Snakemake input
# (`snakemake` is not defined in this notebook; presumably injected by Snakemake at runtime).
plt.style.use(snakemake.input.mpl_style)

In [None]:
def plot_grouped_barplot(scores, metric, metrics_pretty_dict, output_path=None, palette=None, ax=None, legend=True):
    """Plot a grouped bar chart of model scores across datasets.

    Each row of ``scores`` becomes one hue group (model) and each column one
    x position (dataset). Every drawn bar is annotated with its value.

    Parameters
    ----------
    scores : pandas.DataFrame
        Models as index, datasets as columns, scores as values.
    metric : str
        Key into ``metrics_pretty_dict``; the mapped value is used as y-axis label.
    metrics_pretty_dict : dict
        Maps metric keys to display names.
    output_path : str or path-like, optional
        If given, a new 6x3 inch figure is created, saved to this path and closed.
    palette : optional
        Passed to ``seaborn.barplot``. Defaults to a "Greys" ramp with one
        shade per distinct model.
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw into instead of creating a figure. The figure is
        then neither saved nor closed here.
    legend : bool, default True
        If True, place the legend to the right of the axes; otherwise remove it.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.

    Raises
    ------
    ValueError
        If neither or both of ``output_path`` and ``ax`` are provided.
    """
    if output_path is None and ax is None:
        raise ValueError("Either output_path or ax must be provided")
    if output_path is not None and ax is not None:
        raise ValueError("Only one of output_path or ax can be provided")

    # rename_axis (instead of renaming a column literally called 'index') also
    # works when the incoming index already carries a name.
    scores_longform = scores.rename_axis('Model').reset_index().melt(
        id_vars=['Model'],
        value_vars=scores.columns,
        var_name='Dataset',
        value_name='Score',
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 3))
    else:
        # Keep the caller-visible side effect of making `ax` the current axes.
        plt.sca(ax)

    n_models = scores_longform['Model'].nunique()
    sns.barplot(
        x='Dataset',
        y='Score',
        hue='Model',
        data=scores_longform,
        palette=sns.color_palette("Greys", n_models) if palette is None else palette,
        ax=ax,
    )

    # Write each bar's value just above it, rotated by 90 degrees so it fits narrow bars.
    for bar in ax.patches:
        height = bar.get_height()
        # Skip zero-width patches and bars whose value is missing (NaN height);
        # neither can be positioned or labelled meaningfully.
        if bar.get_width() == 0 or pd.isna(height):
            continue
        ax.annotate(
            f' {height:.2f}',
            (bar.get_x() + bar.get_width() / 2., height),
            ha='center',
            va='bottom',
            rotation=90,
            fontsize=6,
            color='black',
        )

    ax.set_ylabel(metrics_pretty_dict[metric])
    ax.set_xlabel("")
    ax.set_ylim(0, 1)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(45)
        tick_label.set_ha('right')

    if legend:
        ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    elif ax.get_legend() is not None:
        ax.get_legend().remove()

    # Lay out only after labels, limits and legend are final, so no space is
    # reserved for the x-label that was just cleared.
    ax.figure.tight_layout()

    if output_path is not None:
        ax.figure.savefig(output_path)
        plt.close(ax.figure)

    return ax


In [None]:
# Metric keys (as listed in snakemake.params.metrics) -> row labels looked up
# in the CellWhisperer macro-average CSVs.
metrics_translation_dict = {
    "auroc": "rocauc_macroAvg",
    "accuracy": "accuracy_macroAvg",
    "f1": "f1_macroAvg",
}
# Metric keys -> y-axis labels.
metrics_pretty_dict = {
    "auroc": "AUROC",
    "accuracy": "Accuracy",
    "f1": "F1 Score",
}

# Dataset keys -> column labels shown on the x-axis.
datasets_pretty_dict = {
    "tabula_sapiens": "Tabula Sapiens",
    "pancreas": "Pancreas",
    "immgen": "ImmGen",
    "aida": "AIDA",
    "tabula_sapiens_well_studied_celltypes": "Tabula Sapiens\n(20 common cell types)",
}

# Model keys (Snakemake `model` wildcard / row names of the fine-tuning tables) -> display names.
models_pretty_dict = {
    "uce": "UCE",
    "geneformer": "Geneformer",
    "scgpt": "scGPT",
    "CellWhisperer": "CellWhisperer",
    "CellAssign": "CellAssign",
    "cellwhisperer_clip_v1": "CellWhisperer (Geneformer-based)",
    "cellwhisperer_clip_v2_scgpt": "CellWhisperer (scGPT-based)",
    "cellwhisperer_clip_v2_uce": "CellWhisperer (UCE-based)",
}

# Training-option keys -> suffix shown in parentheses after a fine-tuned model's name.
training_options_pretty_dict = {
    "unfrozen": "unfrozen, with pseudo-bulking",
    "frozen_singlecells": "frozen, without pseudo-bulking",
    "frozen": "frozen, with pseudo-bulking",
}

# One shade per training option for the fine-tuned baselines. UCE, scGPT and
# Geneformer share the same ramp; the plotting cell keeps only the backbone
# matching the selected CellWhisperer model.
TRAINING_OPTION_COLORS = {
    "unfrozen": "#CAB2D6",            # light purple
    "frozen_singlecells": "#6A3D9A",  # medium purple
    "frozen": "#3F007D",              # deep purple
}
CELLWHISPERER_COLOR = "#ee9703"  # orange, as used for the cosine distance
CELLASSIGN_COLOR = "grey"

# Hue palette keyed by the exact row labels used in the plots:
# "<model> (<training option>)" for fine-tuned baselines. Deriving the keys from
# the pretty-name dicts keeps them in sync with the labels.
palette = {
    f"{models_pretty_dict[model]} ({training_options_pretty_dict[option]})": color
    for model in ("uce", "scgpt", "geneformer")
    for option, color in TRAINING_OPTION_COLORS.items()
}
palette["CellAssign"] = CELLASSIGN_COLOR
# All CellWhisperer variants get the same colour. This includes the plain
# "CellWhisperer" key, which previously had no entry; seaborn rejects a dict
# palette that lacks a plotted hue level.
palette.update({
    models_pretty_dict[key]: CELLWHISPERER_COLOR
    for key in ("CellWhisperer", "cellwhisperer_clip_v1", "cellwhisperer_clip_v2_scgpt", "cellwhisperer_clip_v2_uce")
})

In [None]:
# The CellWhisperer variant under evaluation is selected by the Snakemake `model` wildcard.
cw_model_name_pretty = models_pretty_dict[snakemake.wildcards.model]
metrics = snakemake.params.metrics
training_options = snakemake.params.training_options


def _parse_cw_score(value):
    """Convert a CellWhisperer macro-average cell to float.

    The replace calls strip a ``tensor(...)`` wrapper. Values stored as strings
    like ``"tensor(0.91)"`` and plain numbers are therefore both accepted.
    """
    return float(str(value).replace("tensor(", "").replace(")", ""))


def _finetuned_label(model_key, training_option):
    """Display label for a fine-tuned baseline, e.g. "UCE (frozen, with pseudo-bulking)"."""
    return f"{models_pretty_dict[model_key]} ({training_options_pretty_dict[training_option]})"


# Per-dataset inputs do not depend on the metric, so read each CSV only once.
cw_performance_dfs = {}
for label_col, dataset_macroavg_path, dataset_name in zip(
    snakemake.params.label_cols, snakemake.input.cw_macroaverages, snakemake.params.datasets
):
    assert label_col == "celltype", f"Label column {label_col} is not supported. Only 'celltype' is allowed."
    cw_performance_dfs[dataset_name] = pd.read_csv(dataset_macroavg_path, index_col=0)

marker_based_method_performance_dfs = {
    dataset: pd.read_csv(performance_path, index_col=0)
    for dataset, performance_path in zip(snakemake.params.datasets, snakemake.input.marker_based_method_performances)
}

n_metrics = len(metrics)
# squeeze=False keeps `axes` 2-D even for a single metric (plt.subplots would
# otherwise return a bare Axes and axes[i] would fail); take the only column.
fig, axes = plt.subplots(n_metrics, 1, figsize=(6, 0.8 * n_metrics), sharex=True, squeeze=False)
axes = axes[:, 0]

for i, metric in enumerate(metrics):
    # CellWhisperer and CellAssign scores for this metric, one row per dataset.
    cw_metric_row = metrics_translation_dict[metric]
    cw_df = pd.DataFrame.from_dict(
        {name: _parse_cw_score(perf.loc[cw_metric_row].item()) for name, perf in cw_performance_dfs.items()},
        orient='index',
        columns=[cw_model_name_pretty],
    )
    marker_based_method_df = pd.DataFrame.from_dict(
        {name: perf.loc[metric].item() for name, perf in marker_based_method_performance_dfs.items()},
        orient='index',
        columns=['CellAssign'],
    )

    all_scores = []
    for j, training_option in enumerate(training_options):
        # Aggregated prediction files are flattened metric-major: all training
        # options for the first metric, then all for the second, and so on.
        aggregated_file = snakemake.input.aggregated_predictions_finetuned_models[i * len(training_options) + j]
        # NOTE: substring checks only ("frozen" also matches "unfrozen"), so this
        # is a sanity check rather than a guarantee that the right file was picked.
        assert metric in aggregated_file and training_option in aggregated_file, f"Missing metric {metric} or training option {training_option} in {aggregated_file}"

        finetuned_scores = pd.read_csv(aggregated_file, index_col=0)
        # Keep the fine-tuned backbone(s) whose key occurs in the CellWhisperer
        # model name, e.g. "scgpt" in "cellwhisperer (scgpt-based)".
        keep_rows = [row for row in finetuned_scores.index if row in cw_model_name_pretty.lower()]
        finetuned_scores = finetuned_scores.loc[keep_rows]

        if j == 0:
            # Attach the CellWhisperer and CellAssign rows once, with the first
            # training option. The inner merges keep only datasets present in all sources.
            scores = (
                finetuned_scores.T
                .merge(cw_df, left_index=True, right_index=True)
                .merge(marker_based_method_df, left_index=True, right_index=True)
                .T
            )
            # The last two rows are the CellWhisperer and CellAssign columns merged above.
            scores.index = [_finetuned_label(row, training_option) for row in scores.index[:-2]] + [cw_model_name_pretty, "CellAssign"]
        else:
            scores = finetuned_scores.set_axis(
                [_finetuned_label(row, training_option) for row in finetuned_scores.index], axis=0
            )
        all_scores.append(scores)

    # Combine all training options, then order rows and columns (reverse
    # alphabetical) for a stable plotting order.
    all_scores_df = pd.concat(all_scores, axis=0).sort_index(ascending=False)
    all_scores_df.columns = [datasets_pretty_dict[col] for col in all_scores_df.columns]
    all_scores_df = all_scores_df.sort_index(axis=1, ascending=False)

    all_scores_df.to_csv(snakemake.output.scores_across_training_options[i], index=True, header=True)

    # Stand-alone figure for this metric ...
    plot_grouped_barplot(scores=all_scores_df,
                         metric=metric,
                         metrics_pretty_dict=metrics_pretty_dict,
                         output_path=snakemake.output.barplots_across_training_options[i],
                         palette=palette)
    # ... and the same panel in the combined multi-metric figure (legend only on the top panel).
    plot_grouped_barplot(scores=all_scores_df,
                         metric=metric,
                         metrics_pretty_dict=metrics_pretty_dict,
                         palette=palette,
                         ax=axes[i],
                         legend=(i == 0))

fig.subplots_adjust(hspace=0.65)
fig.savefig(snakemake.output.barplots_across_training_options_across_metrics, bbox_inches='tight')

