# Plotting box plots for Tissue

In [None]:
# Per-F_Gene box plots of MCC test scores on the Tissue split:
# - one figure per (Dataset, All_Genes) configuration listed in plot_configs
# - one fixed colour per F_Gene category, shared across figures
# - individual runs overlaid as a swarm; mean 'dummy' score drawn as a reference line

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load data
df = pd.read_csv("/work/magroup/nzh/Heimdall-dev/plotting/tissue.csv")

# Preprocessing: drop rows with any missing value, then map raw featurizer
# identifiers to the display names used on the axes. Values not listed here
# (e.g. 'dummy') are left unchanged.
df = df.dropna()
df = df.replace({
    'F_Expression': {
        'nonzero_2nn': 'Continuous',
        'scfoundation': 'Autobin',
        'sorting': 'Sorting',
        'binning': 'Binning',
    },
    'F_Gene': {
        'identity': 'Random',
        'pca_esm2': 'ESM2',
        'pca_gene2vec': 'Gene2Vec',
        'pca_hyenadna': 'HyenaDNA',
        'pca_genept': 'GenePT',
    },
})

# x-axis order of the F_Gene categories; rows with any other F_Gene value
# (such as 'dummy') are not drawn as boxes.
gene_order = ['HyenaDNA', 'ESM2', 'Random', 'GenePT', 'Gene2Vec']

# Plotting aesthetics. The palette is sized by gene_order, not by
# df["F_Gene"].nunique(): that count also includes non-plotted values like
# 'dummy', and would fall below len(gene_order) if a gene were missing from the
# CSV, leaving gene_palette without a colour for some plotted genes.
sns.set(style="white", context="talk")
palette = sns.color_palette("Set2", n_colors=len(gene_order))
gene_palette = dict(zip(gene_order, palette))

# (Dataset value, All_Genes flag, figure title)
plot_configs = [
    ('Split1', False, '$F_{GENE}$ over Tissue Split - Only Nonzero Genes'),
    ('Split1', True, '$F_{GENE}$ over Tissue Split - All Genes'),
]

# Create separate plots
for dataset, all_genes, title in plot_configs:
    subset = df[(df['Dataset'] == dataset) & (df['All_Genes'] == all_genes)]
    # Check for data first so an empty configuration is not additionally
    # reported as "no dummy baseline".
    if subset.empty:
        print(f"No data for {title}")
        continue

    # Reference level: mean test score over rows where both featurizers are 'dummy'.
    dummy_baseline_rows = subset[(subset['F_Gene'] == 'dummy') & (subset['F_Expression'] == 'dummy')]
    if dummy_baseline_rows.empty:
        print("Warning: No dummy baseline found. No baseline line will be plotted.")
        baseline_value = None
    else:
        baseline_value = dummy_baseline_rows['Test Score'].mean()
        print(f"Baseline (mean Test Score where dummy) = {baseline_value:.4f}")

    plt.figure(figsize=(8, 6))

    # whis=np.inf extends the whiskers to the data min/max and fliersize=0 hides
    # outlier markers; every individual run is drawn by the swarm plot instead.
    sns.boxplot(
        data=subset,
        x='F_Gene', y='Test Score',
        order=gene_order,
        whis=np.inf,
        width=0.6,
        palette=gene_palette,
        fliersize=0
    )

    sns.swarmplot(
        data=subset,
        x='F_Gene', y='Test Score',
        order=gene_order,
        palette=gene_palette,
        size=8,
        linewidth=0.5,
        edgecolor='black'
    )

    # Y-limits are left to matplotlib's autoscaling.

    # Red dotted horizontal line at the baseline value.
    # NOTE(review): the label says "Optimized Linear Baseline" while the value is
    # the mean of the 'dummy'/'dummy' rows — confirm these are the same runs.
    if baseline_value is not None:
        plt.axhline(
            y=baseline_value,
            color='red',
            linestyle='dotted',
            linewidth=2,
            label='Optimized Linear Baseline'
        )
        plt.legend(loc='lower right', fontsize=12)

    plt.title(title, fontsize=20)

    plt.xlabel("$F_{Gene}$", fontsize=20, weight='bold')
    plt.ylabel("MCC Test Score", fontsize=16, weight='bold')
    plt.xticks(rotation=45, ha='right', fontsize=16, weight="bold")
    plt.yticks(fontsize=20)

    sns.despine()
    plt.tight_layout()
    plt.show()



# Stratified Box Plots

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

# 1) Load & clean --------------------------------------------------------------------
# Raw featurizer identifiers found in the CSV -> display names used in figures.
_EXPRESSION_LABELS = {
    'nonzero_2nn': 'Continuous',
    'scfoundation': 'Autobin',
    'sorting': 'Sorting',
    'binning': 'Binning',
}
_GENE_LABELS = {
    'identity': 'Random',
    'pca_esm2': 'ESM2',
    'pca_gene2vec': 'Gene2Vec',
    'pca_hyenadna': 'HyenaDNA',
    'pca_genept': 'GenePT',
}


def load_and_clean(path):
    """Read the results CSV at *path* and normalise featurizer names.

    Rows containing any missing value are dropped. Values in the
    'F_Expression' and 'F_Gene' columns are renamed through the label tables
    above; values not in those tables (e.g. 'dummy') are kept as-is.
    Returns a new DataFrame.
    """
    raw = pd.read_csv(path)
    complete = raw.dropna()
    renames = {'F_Expression': _EXPRESSION_LABELS, 'F_Gene': _GENE_LABELS}
    return complete.replace(renames)

# 2) Stratified & hatched plotting ---------------------------------------------------
def plot_overlap(
    df, overlap, all_genes,
    gene_order, expr_order, gene_palette, hatch_map,
    box_width=0.2, spacing=0.05, extra_gap=0.2,
    rng=None,
):
    """Draw hatched box plots of 'Test Score', clustered by F_Gene and split by F_Expression.

    Each F_Gene category gets a cluster of boxes, one per F_Expression value in
    ``expr_order``. Boxes are filled with the gene's colour and hatched by
    expression method, and every individual run is overlaid as a jittered point.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have columns 'Dataset', 'All_Genes', 'F_Gene', 'F_Expression'
        and 'Test Score'.
    overlap : str
        Value of ``df.Dataset`` to plot. "Split1" uses y-limits (0.3, 0.5);
        any other value uses (0, 0.7).
    all_genes : bool
        Value of ``df.All_Genes`` to plot; also selects the title suffix.
    gene_order, expr_order : list of str
        Cluster order on the x-axis, and box order within each cluster.
    gene_palette : dict
        F_Gene -> colour; needs an entry for every gene in ``gene_order`` that has data.
    hatch_map : dict
        F_Expression -> matplotlib hatch pattern; needs an entry for every
        value in ``expr_order``.
    box_width, spacing, extra_gap : float
        Box width, gap between boxes in a cluster, and extra gap between
        clusters (x-axis data units).
    rng : numpy.random.Generator, optional
        Source of the horizontal jitter. Defaults to the global ``np.random``
        state; pass e.g. ``np.random.default_rng(0)`` for reproducible figures.

    Returns None. The figure is displayed with ``plt.show()``. Prints a message
    and returns early when no rows match ``overlap``/``all_genes``.
    """
    sub = df[(df.Dataset == overlap) & (df.All_Genes == all_genes)]
    if sub.empty:
        print(f"No data for {overlap}, All_Genes={all_genes}")
        return
    if rng is None:
        rng = np.random

    # Per-box offsets around each cluster centre, so boxes in a cluster don't overlap.
    n = len(expr_order)
    offsets = (np.arange(n) - (n - 1) / 2) * (box_width + spacing)
    group_total = n * box_width + (n - 1) * spacing
    gene_sep = group_total + extra_gap
    centers = np.arange(len(gene_order)) * gene_sep

    plt.figure(figsize=(10, 6))
    for gi, gene in enumerate(gene_order):
        for ei, expr in enumerate(expr_order):
            scores = sub.query("F_Gene==@gene and F_Expression==@expr")['Test Score']
            if scores.empty:
                continue
            x = centers[gi] + offsets[ei]

            # showfliers=False: every point, outliers included, is drawn by the
            # scatter below, so box fliers would draw outliers twice.
            plt.boxplot(
                scores,
                positions=[x],
                widths=box_width,
                patch_artist=True,
                showfliers=False,
                boxprops=dict(
                    facecolor=gene_palette[gene],
                    edgecolor='k',
                    hatch=hatch_map[expr]
                ),
                whiskerprops=dict(color='k'),
                capprops=dict(color='k'),
                medianprops=dict(color='k'),
            )
            jitter = rng.normal(scale=box_width * 0.15, size=len(scores))
            plt.scatter(
                x + jitter,
                scores,
                color=gene_palette[gene],
                edgecolor='k',
                s=40,
                alpha=0.8,
                zorder=3
            )

    # title / labels
    title_flag = "All Genes" if all_genes else "Expressed Genes"
    title = f"$F_{{Gene}}$ Over Tissue Split - {title_flag}"
    # Alternative that names the dataset instead of "Tissue Split":
    # title = f"$F_{{Gene}}$ Over {overlap} - {title_flag}"
    plt.title(title, fontsize=20)
    plt.xlabel("$F_{Gene}$", fontsize=20, weight='bold')
    plt.ylabel("MCC Test Score", fontsize=16, weight='bold')
    plt.xticks(centers, gene_order, rotation=45, ha='right', fontsize=16, weight="bold")
    plt.yticks(fontsize=20)

    # y-axis limits (fixed per dataset)
    if overlap == "Split1":
        plt.ylim(0.3, 0.5)
    else:
        plt.ylim(0, 0.7)

    # Legend for the hatch patterns, placed just outside the top-right of the axes.
    # loc='upper left' is needed with a point bbox_to_anchor; the default
    # loc='best' would position the legend ambiguously around that point.
    handles = [
        Patch(facecolor='lightgray', edgecolor='k', hatch=hatch_map[e], label=e)
        for e in expr_order
    ]
    plt.legend(
        handles=handles,
        title="F_Expression",
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1.02, 1),
        title_fontsize=16,
        fontsize=14
    )

    sns.despine()
    plt.tight_layout()
    plt.show()


# 3) Run -------------------------------------------------------------------------------
df = load_and_clean("/work/magroup/nzh/Heimdall-dev/plotting/tissue.csv")

# Cluster order on the x-axis, and box order inside each cluster.
gene_order = ['HyenaDNA', 'ESM2', 'Random', 'GenePT', 'Gene2Vec']
expr_order = ['Sorting', 'Binning', 'Autobin', 'Continuous']
# One Set2 colour per gene, in gene_order.
gene_palette = dict(zip(gene_order, sns.color_palette("Set2", len(gene_order))))

# One matplotlib hatch pattern per expression featurizer.
hatch_map = dict(
    Sorting='///',
    Binning='\\' * 3,
    Autobin='xxx',
    Continuous='...',
)

# Same plotting order as before: for each dataset, All_Genes=True then False.
for overlap, flag in [(o, f) for o in ['Split1'] for f in (True, False)]:
    plot_overlap(
        df,
        overlap=overlap,
        all_genes=flag,
        gene_order=gene_order,
        expr_order=expr_order,
        gene_palette=gene_palette,
        hatch_map=hatch_map,
        box_width=0.2,
        spacing=0.05,
        extra_gap=0.2,
    )
