In [None]:
from __future__ import annotations  # must be the first statement of the cell/file

# Standard library
import json

# Third-party
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.stats import ttest_rel, wilcoxon, shapiro
from statsmodels.stats.multitest import multipletests
import matplotlib.font_manager as fm
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

In [None]:
import os
import warnings

# Local Arial TrueType file; edit this path (or install Arial system-wide) on other machines.
ARIAL_FONT_PATH = '/work/magroup/skrieger/Arial.ttf'

# Apply the scanpy / seaborn themes FIRST: both rewrite matplotlib rcParams
# (seaborn's set_theme, for one, sets the font family), so anything configured
# before them could be silently overwritten.
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme(style="white")

# Keep text editable in vector output: TrueType (Type 42) fonts in PDFs, real text in SVGs.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'

# Register Arial from file when it exists; otherwise warn and let matplotlib fall
# back to DejaVu Sans (shipped with matplotlib) instead of crashing the notebook.
if os.path.isfile(ARIAL_FONT_PATH):
    fm.fontManager.addfont(ARIAL_FONT_PATH)
else:
    warnings.warn(
        f"Arial font file not found at {ARIAL_FONT_PATH!r}; falling back to DejaVu Sans."
    )
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']

<h2>Panel B</h2>

In [None]:

# Fixed method order and set names (used for consistent plotting)
method_order = ["scbert", "scgpt", "uce", "scfound", "geneformer"]
set_names_default = ("All genes", "Expressed genes")

def plot_mcc_grouped(
    mcc_all_genes: dict,
    mcc_expressed_genes: dict,
    title: str = "MCC Comparison (All vs Expressed)",
    save_prefix: str = "mcc_grouped",
    set_names: tuple[str, str] = set_names_default,
    baseline_all: float | None = None,
    baseline_expressed: float | None = None,
    figsize: tuple[float, float] = (7.5, 6),
    ylim: tuple[float, float] = (0.0, 1.0),
    ytick_step: float = 0.10,
):
    """
    Grouped bar plot comparing MCC across methods for two gene sets.

    Methods are on the x-axis. Each method has one bar per entry of ``set_names``
    (e.g. 'All genes' vs 'Expressed genes'). Colors encode METHODS (Set2) and
    hatches encode the SET. Bar height is the mean of the given scores, and the
    error bar is one standard deviation (seaborn ``errorbar="sd"``).

    Methods are drawn in the global ``method_order``. Methods present in the
    input dicts but absent from ``method_order`` are appended in sorted order,
    so they are not silently dropped.

    Parameters
    ----------
    mcc_all_genes : dict[str, Sequence[float]]
        Per-method MCC scores (e.g. one per run) for the first set.
    mcc_expressed_genes : dict[str, Sequence[float]]
        Per-method MCC scores for the second set.
    title : str
        Axes title.
    save_prefix : str
        Output path. ".pdf" is appended unless already present. Falsy means no file is written.
    set_names : (str, str)
        Hue labels of the two sets, in plotting order.
    baseline_all : float | None
        Optional reference line (red dashed).
    baseline_expressed : float | None
        Optional reference line (black dotted).
    figsize : (float, float)
        Figure size in inches.
    ylim : (float, float)
        y-axis limits (default 0-1).
    ytick_step : float
        Spacing of the major y ticks (default 0.1).

    Raises
    ------
    ValueError
        If neither dict contains any score.
    """
    # Illustrator-friendly font embedding (TrueType / Type 42 in PDF and PS)
    matplotlib.rcParams["pdf.fonttype"] = 42
    matplotlib.rcParams["ps.fonttype"] = 42

    pretty_name = {
        "scbert":     "scBERT",
        "scgpt":      "scGPT",
        "uce":        "UCE",
        "scfound":    "scFoundation",
        "geneformer": "Geneformer",
    }

    # ---- Method order: the fixed global order first, then any extra methods found
    # in the data (sorted), so no method is dropped from the plot.
    observed = set(mcc_all_genes) | set(mcc_expressed_genes)
    fixed = list(method_order) if method_order else []
    all_methods = fixed + sorted(observed - set(fixed))

    # ---- Build tidy (long-format) DataFrame: one row per (method, set, score)
    rows = []
    for m in all_methods:
        if m in mcc_all_genes:
            rows.extend({"method": m, "set": set_names[0], "mcc": s} for s in mcc_all_genes[m])
        if m in mcc_expressed_genes:
            rows.extend({"method": m, "set": set_names[1], "mcc": s} for s in mcc_expressed_genes[m])
    if not rows:
        raise ValueError("plot_mcc_grouped: no MCC scores to plot (both input dicts are empty).")
    df = pd.DataFrame(rows)

    # ---- Palettes / hatches
    method_palette = dict(zip(
        all_methods,
        sns.color_palette("Set2", n_colors=max(len(all_methods), 3))
    ))
    set_hatch_map = {
        set_names[0]: "",      # All genes
        set_names[1]: "///",   # Expressed genes
    }

    # ---- Plot
    fig, ax = plt.subplots(figsize=figsize, facecolor="white")

    sns.barplot(
        data=df,
        x="method",
        y="mcc",
        hue="set",
        order=all_methods,
        hue_order=list(set_names),
        palette=["#cccccc", "#999999"],  # temp colors; we overwrite below
        errorbar="sd",
        capsize=0.1,
        err_kws={"color": "black", "linewidth": 1.25},
        ax=ax,
    )

    # Pretty method labels on x-axis
    ax.set_xticklabels([pretty_name.get(t.get_text(), t.get_text()) for t in ax.get_xticklabels()],
                       rotation=0, ha="center")

    # Re-style bars to use method colors + set hatches. Each bar is mapped to its
    # method through its x-center (category i sits at x == i, dodged by less than 0.5).
    # Its position inside the container is not used: a set with no scores for some
    # method yields fewer bars than there are methods, which would shift the colors.
    for set_label, container in zip(set_names, ax.containers):
        for bar in container:
            method_idx = int(round(bar.get_x() + bar.get_width() / 2))
            if not 0 <= method_idx < len(all_methods):
                continue
            bar.set_facecolor(method_palette[all_methods[method_idx]])
            bar.set_edgecolor("black")
            bar.set_hatch(set_hatch_map[set_label])

    # ---- Baselines
    baseline_handles = []
    if baseline_all is not None:
        ax.axhline(y=baseline_all, linestyle="--", color="red", linewidth=1.5)
        baseline_handles.append(Line2D([0], [0], color="red", lw=1.5, linestyle="--",
                                       label="Linear baseline (All)"))
    if baseline_expressed is not None:
        ax.axhline(y=baseline_expressed, linestyle=":", color="black", linewidth=1.5)
        baseline_handles.append(Line2D([0], [0], color="black", lw=1.5, linestyle=":",
                                       label="Linear baseline (Expr)"))

    # ---- Legends (first: sets + baselines; second: methods)
    set_handles = [
        Patch(facecolor="white", edgecolor="black", hatch=set_hatch_map[s], label=s)
        for s in set_names
    ]
    first_legend = ax.legend(
        handles=set_handles + baseline_handles,
        title="Data subset",
        frameon=False,
        bbox_to_anchor=(1.10, 1.00),
        loc="upper left",
        borderaxespad=0,
    )
    # A second ax.legend() call replaces the first, so re-attach it as an artist
    ax.add_artist(first_legend)

    method_handles = [
        Patch(facecolor=method_palette[m], edgecolor="black", label=pretty_name.get(m, m))
        for m in all_methods
    ]
    ax.legend(
        handles=method_handles,
        title="Method",
        frameon=False,
        bbox_to_anchor=(1.15, 0.45),
        loc="upper left",
        borderaxespad=0,
    )

    # ---- Titles / labels
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("MCC")

    # ---- y ticks: fixed limits and evenly spaced major ticks
    ax.set_ylim(*ylim)
    ax.yaxis.set_major_locator(mticker.MultipleLocator(ytick_step))
    # Enough decimals to keep tick labels distinct (0.1 -> "%.1f", 0.05 -> "%.2f")
    step_txt = f"{ytick_step:f}".rstrip("0")
    decimals = max(1, len(step_txt.split(".")[1]))
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(f"%.{decimals}f"))

    ax.minorticks_off()  # disable minor ticks
    ax.tick_params(axis="y", which="major", length=6, width=1, direction="out")

    # Move ticks to the right and show right spine
    ax.spines["right"].set_visible(True)
    ax.yaxis.set_ticks_position("right")

    # x-axis tick styling
    ax.tick_params(axis="x", which="major", length=6, width=1, direction="out")

    fig.tight_layout()

    # ---- Save & show
    if save_prefix:
        outpath = save_prefix if save_prefix.lower().endswith(".pdf") else f"{save_prefix}.pdf"
        fig.savefig(outpath, facecolor="white", bbox_inches="tight", pad_inches=0.02)

    plt.show()


In [None]:
# Random-split MCC per method (presumably one value per seed/run) for the
# "All genes" and "Expressed genes" settings.
# NOTE(review): this cell sits under the "Panel B" header but saves "PanelC",
# while the "Panel C" cell saves "PanelB" — confirm which naming is intended.
mcc_all = dict(
    scbert=[0.94867724, 0.94609857, 0.94294327],
    scgpt=[0.94376308, 0.94817495, 0.94581705],
    uce=[0.93410349, 0.93321985, 0.93982255],
    scfound=[0.94861931, 0.95554936, 0.95099777],
    geneformer=[0.92469352, 0.92147833, 0.93449092],
)

mcc_expr = dict(
    scbert=[0.95078743, 0.94896245, 0.94858998],
    scgpt=[0.94527137, 0.9511742, 0.95115674],
    uce=[0.93845016, 0.93661916, 0.93552315],
    scfound=[0.95342636, 0.94814163, 0.9492901],
    geneformer=[0.94382459, 0.93971032, 0.94171149],
)

plot_mcc_grouped(
    mcc_all,
    mcc_expr,
    title="Random Split MCC (All vs Expressed)",
    save_prefix="PanelC",
    set_names=("All genes", "Expressed genes"),
    baseline_all=0.93083292,
    baseline_expressed=None,  # no expressed-genes baseline line in this panel
)


<h2>Panel C</h2>

In [None]:
# Tokenization MCC per method (presumably one value per seed/run) for the
# "All genes" and "Expressed genes" settings.
# NOTE(review): this cell sits under the "Panel C" header but saves "PanelB",
# while the "Panel B" cell saves "PanelC" — confirm which naming is intended.
mcc_all = dict(
    scbert=[0.49071321, 0.49229085, 0.60595125, 0.63054377, 0.59705812],
    scgpt=[0.46976808, 0.4866567, 0.48075321, 0.46969351, 0.47225532],
    uce=[0.44531417, 0.43281052, 0.42603561, 0.44584182, 0.45910525],
    scfound=[0.50881106, 0.47597343, 0.49842638, 0.49933156, 0.47142515],
    geneformer=[0.46796301, 0.44147798, 0.39268956, 0.36072132, 0.40160018],
)
mcc_expr = dict(
    scbert=[0.5486, 0.5548, 0.5146, 0.5695, 0.5613],
    scgpt=[0.473, 0.4425, 0.464, 0.4451, 0.5325],
    uce=[0.4461, 0.4615, 0.4594, 0.4588, 0.4567],
    scfound=[0.4889, 0.4732, 0.4426, 0.4715, 0.464],
    geneformer=[0.4858, 0.4565, 0.4502, 0.4711, 0.4633],
)

plot_mcc_grouped(
    mcc_all,
    mcc_expr,
    title="Tokenization MCC (All vs Expressed)",
    save_prefix="PanelB",
    set_names=("All genes", "Expressed genes"),
    baseline_all=0.39696,
    baseline_expressed=None,  # no expressed-genes baseline line in this panel
)


<h2>UMAP</h2>

In [None]:
# scBERT: UMAP of the test AnnData, colored by the "celltypes" annotation.
# basis="umap" makes scanpy read the coordinates from test_adata.obsm["X_umap"].
test_adata = sc.read_h5ad('/path/to/adata.h5ad')

sc.pl.embedding(
    test_adata,
    basis="umap",
    color="celltypes",
    save="_umap.pdf",  # written through scanpy's own `save=` mechanism
    show=True,
)

<h2>Confusion Matrices</h2>

In [None]:
# Test-set confusion matrix of one W&B run, row-normalized so that each row shows
# how the samples of one actual class are distributed over the predicted classes.
# NOTE(review): `wandb` is not imported anywhere in this notebook; add
# `import wandb` to the imports cell before running this one.
api = wandb.Api()
run = api.run("/path/to/run")

# The run summary holds a reference to the logged table; its "path" entry names
# the run file that stores the table as JSON.
cm_info = run.summary["test_confusion_matrix_table"]

# download() hands back an open handle to the local copy; read it inside `with`
# so the handle is closed instead of leaked.
with run.file(cm_info["path"]).download(exist_ok=True) as artifact_file:
    table_data = json.load(artifact_file)

# Reconstruct the long-format table: one row per (Actual, Predicted) pair
columns = table_data["columns"]
data = table_data["data"]
df = pd.DataFrame(data, columns=columns)

# Pivot into a square matrix. Use the union of actual and predicted labels on BOTH
# axes so rows and columns stay aligned even when a class is never predicted
# (missing column) or never occurs as ground truth (missing row).
df_pivot = df.pivot(index="Actual", columns="Predicted", values="nPredictions").fillna(0)
labels = df_pivot.index.union(df_pivot.columns)
df_pivot = df_pivot.reindex(index=labels, columns=labels, fill_value=0)

# Row-normalize (each non-empty row sums to 1); empty rows (0/0 -> NaN) become zeros
normalized_df = df_pivot.div(df_pivot.sum(axis=1), axis=0).fillna(0)

# Plot
fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(
    normalized_df,
    annot=False,
    cmap="Reds",
    xticklabels=df_pivot.columns,
    yticklabels=df_pivot.index,
    square=True,
    linewidths=0.5,
    linecolor="black",
    cbar=True,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Test Confusion Matrix")
fig.tight_layout()
fig.savefig("/path/to/confusion.pdf", format="pdf", bbox_inches="tight")
plt.show()


<h2>scBERT Ablation</h2>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

# Per-seed MCC scores for the scBERT tokenization ablations, one dict per module.
# Judging by the method labels: Fg = gene representation, Fe = expression
# encoding, Fo = gene ordering, Fs = sampling (TODO confirm naming).
# Each list holds one score per seed. It must follow the same seed order as the
# scBERT-tok `baseline` array, because the paired tests subtract the two element-wise.
Fg_scores = {
    "identity": [0.45751134, 0.45100579, 0.49813652, 0.46643114, 0.48736089, 0.46875244],
    "esm2":     [0.44273233, 0.44392687, 0.43577412, 0.45891109, 0.47342929, 0.43570945],
    "hyenadna": [0.44298422, 0.44606373, 0.42005789, 0.43307543, 0.41069844, 0.43457267],
    "genept":   [0.47888815, 0.49243236, 0.42832363, 0.50483006, 0.48969954, 0.49101087],
}

Fe_scores = {
    "scgpt_bin":  [0.55817157, 0.61578941, 0.60174245, 0.58067197, 0.53753412, 0.53839368],
    "autobin":    [0.52708822, 0.56023854, 0.58409822, 0.53551513, 0.49547338, 0.45878071],
    "continuous": [0.57259917, 0.53929651, 0.58866841, 0.58620977, 0.57529479, 0.58249724],
}

# NOTE(review): "exprsort" is identical to Fe "continuous" above. Presumably the
# same runs are shared by both ablations; confirm this is intentional.
Fo_scores = {
    "exprsort":  [0.57259917, 0.53929651, 0.58866841, 0.58620977, 0.57529479, 0.58249724],
    "chromsort": [0.59442306, 0.56763101, 0.50425416, 0.48647985, 0.51687235, 0.62693185],
}

Fs_scores = {
    "weighted": [0.52629685, 0.53090668, 0.58057767, 0.57593179, 0.48153391, 0.52228451],
}

# All ablation groups consumed by the statistics and plotting cells
scores_by_group = {"Fg": Fg_scores, "Fe": Fe_scores, "Fo": Fo_scores, "Fs": Fs_scores}

# Bar order within each group
method_order_by_group = {
    "Fg": ["identity", "esm2", "hyenadna", "genept"],
    "Fe": ["scgpt_bin", "autobin", "continuous"],
    "Fo": ["exprsort", "chromsort"],
    "Fs": ["weighted"],
}

# Display names used as y-axis labels
pretty_name_by_group = {
    "Fg": {
        "identity": "Identity",
        "esm2": "ESM2",
        "hyenadna": "HyenaDNA",
        "genept": "GenePT",
    },
    "Fe": {
        "scgpt_bin": "Quantile binning",
        "autobin": "Autobinning",
        "continuous": "Continuous",
    },
    "Fo": {
        "exprsort": "Expression sort",
        "chromsort": "Chromosome sort",
    },
    "Fs": {
        "weighted": "Weighted sample",
    },
}


In [None]:
def paired_tests_vs_baseline(group_dict, baseline_vals):
    """
    Paired significance tests of each method's per-seed scores against a baseline.

    Each ``method -> scores`` entry of ``group_dict`` is paired element-wise with
    ``baseline_vals`` (same seeds, same order). Three tests are run:

    - a paired t-test (``ttest_rel``, scipy default two-sided);
    - a Wilcoxon signed-rank test on the differences;
    - a Shapiro-Wilk test on the differences, reported to help judge the
      t-test's normality assumption.

    p-values are Holm-corrected across the methods of THIS group only.

    Parameters
    ----------
    group_dict : dict[str, Sequence[float]]
        method -> per-seed scores.
    baseline_vals : Sequence[float]
        Per-seed baseline scores, in the same seed order.

    Returns
    -------
    pandas.DataFrame
        One row per method, with these columns:

        - ``method``, ``p_t``, ``p_w``, ``shapiro_W``, ``shapiro_p``;
        - ``mean_diff`` (mean of method - baseline), ``t_stat``, ``w_stat``;
        - ``p_t_holm`` / ``p_t_rej`` and ``p_w_holm`` / ``p_w_rej``: Holm-adjusted
          p-values and reject flags at alpha = 0.05. NaN p-values are left out of
          the correction and stay NaN / not rejected.

    Raises
    ------
    ValueError
        If a method's scores do not have the same shape as ``baseline_vals``.
    """
    b = np.asarray(baseline_vals, dtype=float)
    rows = []
    for method, vals in group_dict.items():
        x = np.asarray(vals, dtype=float)
        if x.shape != b.shape:
            raise ValueError(
                f"Seed counts/order must match for {method}: "
                f"got {x.shape} scores vs baseline {b.shape}."
            )
        d = x - b  # paired differences (method - baseline), one per seed

        # Paired t-test on the raw pairs
        tstat, p_t = ttest_rel(x, b)

        # Wilcoxon signed-rank on the differences. scipy can raise ValueError on
        # degenerate input (e.g. all differences zero), which is treated as "no evidence".
        try:
            wstat, p_w = wilcoxon(d)
        except ValueError:
            wstat, p_w = np.nan, 1.0

        # Shapiro–Wilk on the differences (normality assumption of the paired t-test).
        # Best effort: e.g. fewer than 3 seeds cannot be tested.
        try:
            W, p_sw = shapiro(d)
        except Exception:
            W, p_sw = np.nan, np.nan

        rows.append({
            "method": method, "p_t": float(p_t), "p_w": float(p_w),
            "shapiro_W": W, "shapiro_p": p_sw,
            "mean_diff": float(np.mean(d)), "t_stat": float(tstat), "w_stat": float(wstat),
        })

    columns = ["method", "p_t", "p_w", "shapiro_W", "shapiro_p", "mean_diff", "t_stat", "w_stat"]
    df = pd.DataFrame(rows, columns=columns)

    # Holm correction within this group, separately for the t-test and Wilcoxon
    # p-values. Non-finite p-values (e.g. a zero-variance t-test) are excluded from
    # the family so they cannot contaminate the step-down adjustment.
    for col in ("p_t", "p_w"):
        pvals = df[col].to_numpy(dtype=float)
        finite = np.isfinite(pvals)
        p_adj = np.full(pvals.shape, np.nan)
        rej = np.zeros(pvals.shape, dtype=bool)
        if finite.any():
            rej_f, p_adj_f, _, _ = multipletests(pvals[finite], method="holm")
            rej[finite] = rej_f
            p_adj[finite] = p_adj_f
        df[col + "_holm"] = p_adj
        df[col + "_rej"] = rej

    return df


def p_to_stars(p):
    """
    Convert a p-value into significance stars.

    ``p < 0.001`` gives '***', ``p < 0.01`` gives '**' and ``p < 0.05`` gives '*'.
    Anything else, including NaN, gives ''.
    """
    for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < threshold:
            return stars
    return ''


In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import seaborn as sns

def plot_grouped_mcc(
    groups_dict,
    method_order_by_group=None,
    pretty_name_by_group=None,
    title="Ablations by Module (Fg / Fe / Fo / Fs)",
    save_prefix=None,
    baseline_scores=None,
    baseline_label="Linear baseline",
    baseline_position="bottom",
    group_order=("Fg", "Fe", "Fo", "Fs"),
    stars_by_key=None,
    xlim=(0.35, 0.65),
    mean_line_label="Baseline mean",
    xticks=None,
    xtick_step=0.1,
    minor_xtick_step=None
):
    """
    Grouped horizontal bar plot of per-method MCC, with an optional baseline row and stars.

    Each bar spans from the left x-limit to the method's mean score and carries a
    +/- 1 sample-std (ddof=1) error bar. Bars are colored by group (Set2). The
    baseline, if given, is a white bar plus a red dashed line at its mean.

    Layout note: rows are stacked at increasing y (the first emitted row is at y=0)
    and the y-axis is NOT inverted, so the first row is drawn at the BOTTOM.
    Consequently ``group_order[0]`` ends up lowest. ``baseline_position="top"``
    (baseline emitted first) puts the baseline row visually at the bottom, and
    ``"bottom"`` puts it visually at the top.

    Parameters
    ----------
    groups_dict : dict[str, dict[str, Sequence[float]]]
        group -> method -> per-seed scores.
    method_order_by_group : dict[str, list[str]] | None
        Optional per-group method order. Listed methods missing from
        ``groups_dict[group]`` are skipped. Groups without an entry use dict order.
    pretty_name_by_group : dict[str, dict[str, str]] | None
        Optional display names. Falls back to ``method.title()``.
    title : str
        Axes title.
    save_prefix : str | None
        If given, the figure is saved as ``f"{save_prefix}.pdf"``.
    baseline_scores : Sequence[float] | None
        Baseline scores, summarized by their NaN-aware mean and std.
    baseline_label : str
        Tick/legend label of the baseline row.
    baseline_position : {"top", "bottom"}
        Emit the baseline row before ("top") or after ("bottom") the groups.
        See the layout note above.
    group_order : tuple[str, ...]
        Groups to draw, in emission order. Groups missing from ``groups_dict`` are skipped.
    stars_by_key : dict[tuple[str, str], str] | None
        (group, method) -> annotation drawn just past the error bar.
    xlim : (float, float) | None
        x-axis limits. Bars start at the left limit.
    mean_line_label : str
        Legend label of the baseline mean line.
    xticks : Sequence[float] | None
        Explicit major x ticks. Otherwise ``xtick_step`` spacing is used.
    xtick_step : float | None
        Major x tick spacing.
    minor_xtick_step : float | None
        Minor x tick spacing. Minor ticks are disabled when falsy.

    Baseline-bar and mean-line styling are fixed internally (no public params).
    """

    # Illustrator-friendly font embedding (TrueType / Type 42)
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

    # Fixed internal styling (intentionally not exposed as parameters)
    _BASELINE_STYLE = {"facecolor": "white", "edgecolor": "black", "linewidth": 0.8}
    _SHOW_MEAN_LINE = True
    _MEAN_LINE_STYLE = {"linestyle": "--", "linewidth": 1.5, "color": "red", "alpha": 1.0}

    present_groups = [g for g in group_order if g in groups_dict]

    group_palette = dict(zip(present_groups, sns.color_palette("Set2", len(present_groups)))) if present_groups else {}

    # Per-group method order. Requested methods that have no scores are skipped
    # (otherwise the bar loop below would raise KeyError).
    group_methods = {}
    for g in present_groups:
        requested = (method_order_by_group or {}).get(g)
        if requested is None:
            group_methods[g] = list(groups_dict[g].keys())
        else:
            group_methods[g] = [m for m in requested if m in groups_dict[g]]

    # Baseline summary: NaN-aware mean and sample std (ddof=1)
    b_mean = b_std = None
    if baseline_scores is not None and len(baseline_scores) > 0:
        b_arr = np.asarray(baseline_scores, dtype=float)
        b_mean = float(np.nanmean(b_arr))
        b_std = float(np.nanstd(b_arr, ddof=1)) if len(b_arr) > 1 else 0.0

    BAR_H = 0.9  # bar thickness == row pitch, so bars within a group touch
    GAP = 0.6    # extra space between groups and between baseline and groups

    y_positions, y_labels, row_meta = [], [], []   # row_meta: ("__BASELINE__", None) or (group, method)
    y = 0.0

    def add_baseline_row():
        """Append the baseline row at the current y (no-op without baseline scores)."""
        nonlocal y
        if b_mean is not None:
            y_positions.append(y)
            y_labels.append(baseline_label)
            row_meta.append(("__BASELINE__", None))
            y += BAR_H
            if baseline_position == "top" and present_groups:
                y += GAP

    if baseline_position == "top":
        add_baseline_row()

    for gi, g in enumerate(present_groups):
        for m in group_methods[g]:
            y_positions.append(y)
            pretty = (pretty_name_by_group or {}).get(g, {}).get(m, m.title())
            y_labels.append(pretty)
            row_meta.append((g, m))
            y += BAR_H
        if gi < len(present_groups) - 1:
            y += GAP

    if baseline_position == "bottom":
        if b_mean is not None and present_groups:
            y += GAP
        add_baseline_row()

    def _disable_clip_errorbar(eb):
        """Best effort: let every artist of an errorbar container draw outside the axes."""
        if eb is None:
            return
        # ErrorbarContainer.lines is (data_line, caplines, barlinecols). The last two
        # are tuples of artists, so flatten one level. The data line is None when fmt='none'.
        for part in getattr(eb, "lines", None) or ():
            for artist in (part if isinstance(part, (list, tuple)) else (part,)):
                if artist is not None and hasattr(artist, "set_clip_on"):
                    artist.set_clip_on(False)

    fig, ax = plt.subplots(figsize=(6, 7), facecolor='white')

    if xlim is not None:
        ax.set_xlim(*xlim)
    xmin, xmax = ax.get_xlim()

    if xticks is not None:
        ax.set_xticks(xticks)
    elif xtick_step is not None:
        ax.xaxis.set_major_locator(mticker.MultipleLocator(xtick_step))
    ax.xaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))

    if minor_xtick_step:
        ax.xaxis.set_minor_locator(mticker.MultipleLocator(minor_xtick_step))
    else:
        ax.minorticks_off()

    ax.spines['bottom'].set_visible(True)
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(axis='x', which='major', bottom=True, top=False,
                   length=6, width=1, direction='out')
    if minor_xtick_step:
        ax.tick_params(axis='x', which='minor', bottom=True, top=False,
                       length=3, width=0.8, direction='out')

    mean_line_handle = None
    if _SHOW_MEAN_LINE and (b_mean is not None):
        vline = ax.axvline(b_mean, **_MEAN_LINE_STYLE)
        vline.set_clip_on(False)
        mean_line_handle = Line2D([0], [0],
                                  linestyle=_MEAN_LINE_STYLE["linestyle"],
                                  linewidth=_MEAN_LINE_STYLE["linewidth"],
                                  color=_MEAN_LINE_STYLE["color"],
                                  alpha=_MEAN_LINE_STYLE["alpha"],
                                  label=mean_line_label)

    for y_val, (g, m) in zip(y_positions, row_meta):
        if g == "__BASELINE__":
            width = max(b_mean - xmin, 0.0)
            ax.barh(y_val, width, left=xmin, height=BAR_H,
                    zorder=2, clip_on=False, **_BASELINE_STYLE)
            if b_std and b_std > 0:
                eb = ax.errorbar(b_mean, y_val, xerr=b_std, fmt='none',
                                 ecolor='black', elinewidth=1.2, capsize=3, zorder=3)
                _disable_clip_errorbar(eb)
            continue

        vals = np.asarray(groups_dict[g][m], dtype=float)
        mean = float(np.nanmean(vals))
        std = float(np.nanstd(vals, ddof=1)) if len(vals) > 1 else 0.0
        color = group_palette.get(g, 'C0')

        # Bars start at the left x-limit (not 0) so differences in the zoomed range are visible
        width = max(mean - xmin, 0.0)
        ax.barh(y_val, width, left=xmin, height=BAR_H,
                color=color, edgecolor='black', linewidth=0.6,
                zorder=1, clip_on=False)

        if std > 0:
            eb = ax.errorbar(mean, y_val, xerr=std, fmt='none',
                             ecolor='black', elinewidth=1.2, capsize=3, zorder=3)
            _disable_clip_errorbar(eb)

        # Stars (if provided): start just past the error-bar whisker so they don't
        # overlap it. Clamp inside the right axis limit and right-align when clamped.
        if stars_by_key:
            stars = stars_by_key.get((g, m), '')
            if stars:
                dx = max(0.01 * (xmax - xmin), 1e-3)
                anchor = mean + std + dx
                tx = min(anchor, xmax - 0.004 * (xmax - xmin))
                ha = 'right' if tx < anchor else 'left'
                ax.text(tx, y_val, stars, va='center', ha=ha,
                        fontsize=10, fontweight='bold', color='black',
                        zorder=4, clip_on=False)

    ax.set_yticks(y_positions)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel("MCC")
    ax.set_ylabel("")
    ax.set_title(title)

    handles = [Patch(facecolor=group_palette[g], edgecolor='black', label=g) for g in present_groups]
    if b_mean is not None:
        handles.append(Patch(label=baseline_label,
                             facecolor=_BASELINE_STYLE["facecolor"],
                             edgecolor=_BASELINE_STYLE["edgecolor"]))
    if mean_line_handle is not None:
        handles.append(mean_line_handle)
    if handles:
        ax.legend(handles=handles, title="Groups", bbox_to_anchor=(1.02, 1),
                  loc='upper left', frameon=False)

    # Tight y-range: the first/last bars touch the axes edges
    ax.margins(y=0)
    if y_positions:
        ax.set_ylim(y_positions[0] - BAR_H/2, y_positions[-1] + BAR_H/2)

    # Keep left spine so y-ticks have an anchor
    ax.spines['left'].set_visible(True)

    # Save
    if save_prefix:
        fig.savefig(f"{save_prefix}.pdf", facecolor='white', bbox_inches='tight', pad_inches=0.02)
    plt.show()


In [None]:
# scBERT-tok scores per seed. Must follow the same seed order as every ablation list.
baseline = np.array(
    [0.49752957, 0.50437343, 0.51166874, 0.64673525, 0.60097837, 0.56016135],
    dtype=float,
)

# Paired tests of each ablation vs. the baseline (Holm-corrected within each group)
per_group_results = dict(
    (group, paired_tests_vs_baseline(group_scores, baseline))
    for group, group_scores in scores_by_group.items()
)

# Adjusted p-value column that drives the stars:
# "p_t_holm" (paired t-test) or "p_w_holm" (Wilcoxon)
STAR_COL = "p_t_holm"
stars_by_key = {
    (group, row["method"]): p_to_stars(row[STAR_COL])
    for group, result in per_group_results.items()
    for _, row in result.iterrows()
}

# plot_grouped_mcc stacks rows bottom-up, so this group_order puts Fg at the top
# and the "top"-positioned baseline row at the bottom of the figure.
plot_grouped_mcc(
    groups_dict=scores_by_group,
    method_order_by_group=method_order_by_group,
    pretty_name_by_group=pretty_name_by_group,
    group_order=("Fs", "Fo", "Fe", "Fg"),
    baseline_scores=baseline,
    baseline_label="scBERT-tok",
    baseline_position="top",
    stars_by_key=stars_by_key,
    xlim=(0.3, 0.65),
    save_prefix='/path/to/file',
)
