# Data loading

In [1]:
import scanpy as sc
import pandas as pd

In [2]:
# Load the processed AnnData object from the project data directory.
adata = sc.read_h5ad('../../data/RREAE_5k_raw_integration_processed.h5ad')
# Cast the obs index to strings, then let AnnData de-duplicate repeated observation names
# (the stored cell output contains a fragment of a duplicate-obs-names warning).
adata.obs.index = adata.obs.index.astype(str)
adata.obs_names_make_unique()

  utils.warn_names_duplicates("obs")


In [3]:
# Sample-level annotation table, stored as a semicolon-separated CSV.
# NOTE(review): this is an absolute, user-specific path — consider keeping the file
# alongside the other project data so the notebook runs on other machines.
anno = pd.read_csv(
    '/Users/christoffer/Downloads/2025_Xenium5k_sample_annotation_integration_TS.csv',
    sep=';',
)
anno

Unnamed: 0,sample_id,sample_name,region,course,condition,model
0,S1-B1_1,R1-1_T,T,remitt I,EAE,RR
1,S1-B1_0,OS1-1_L,L,onset I,EAE,RR
2,S1-B1_2,P2-1_C,C,peak I,EAE,RR
3,S1-B2_2,P2-1_L,L,peak II,EAE,RR
4,S1-B2_0,OS1_1_T,T,onset I,EAE,RR
...,...,...,...,...,...,...
102,G6_L3_0,872025 TSC,C,chronic long,EAE,Chronic
103,G6_L3_2,872021 CTSC-C,C,MOG CFA,CONTROL,Chronic
104,G4_L1_1,872025 LSC,L,chronic long,EAE,Chronic
105,G4_L1_0,872022 LSC,L,chronic long,EAE,Chronic


In [4]:
# Copy sample-level annotation columns onto every cell, keyed by the cell's sample_id.

# Surface sample IDs that have no row in the annotation table: Series.map() would
# otherwise fill the new columns with NaN for those cells without any notice.
known_sample_ids = set(anno['sample_id'])
unannotated_ids = [s for s in adata.obs['sample_id'].unique() if s not in known_sample_ids]
if unannotated_ids:
    print(
        f"WARNING: {len(unannotated_ids)} sample_id value(s) have no annotation row "
        f"and will be NaN in the mapped columns: {unannotated_ids}"
    )

for meta in ['sample_name', 'region', 'course', 'condition', 'model']:
    # sample_id -> annotation value lookup for this column
    mapping_dict = dict(zip(anno['sample_id'], anno[meta]))
    adata.obs[meta] = adata.obs['sample_id'].map(mapping_dict)

In [10]:
# Replace the 'raw' layer with an integer-typed copy of itself.
# NOTE(review): astype(int) truncates non-integer values — presumably the layer
# holds whole-number counts stored as floats; confirm.
adata.layers['raw'] = adata.layers['raw'].astype(int)

In [21]:
# Make the 'raw' layer the main expression matrix (adata.X).
adata.X = adata.layers['raw']

In [23]:
# Remove the 'raw' layer now that adata.X holds the same matrix. From here on the
# counts are only reachable through adata.X, so the matrix helpers in this notebook
# must be called with layer=None; asking for layer "raw" raises ValueError.
del adata.layers['raw']

In [24]:
# Display the AnnData summary (shape, obs/var columns, uns keys) before writing it out.
adata

AnnData object with n_obs × n_vars = 891821 × 5101
    obs: 'x_centroid', 'y_centroid', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'run', 'grid_label', 'project', 'strain', 'n_genes_by_counts', 'n_counts', 'n_genes', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2', 'leiden_2.5', 'sample_id', 'sub_type', 'cell_type', 'cell_class', 'area', 'sample_name', 'region', 'course', 'condition', 'model', 'cell_id', 'rbd_domain', 'sub_type_I', 'sub_type_II', 'sub_type_III', 'cell_id_2', 'match_key', 'lesion_density_call', 'lesion_distance_um', 'lesion_distance_bin', 'celltype_merged', 'dist_bin'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'mean', 'std'
    uns: 'cell_class_colors', 'cell_polygons', 'cell_type_colors', 'condition_colors', 'course_colors', 'd

In [25]:
# Save the modified object under a new file name so the input file is left untouched.
adata.write('../../data/RREAE_5k_raw_only_integration_processed.h5ad')

In [9]:
# Per-model course labels: one 'baseline' course plus the list of 'courses' for that model.
MODELS = {
    'MOG': {
        'baseline': 'MOG CFA',
        'courses': [
            'non symptomatic',
            'early onset',
            'chronic peak',
            'chronic long',
        ],
    },
    'PLP': {
        'baseline': 'PLP CFA',
        'courses': [
            'onset I',
            'onset II',
            'peak I',
            'monophasic',
            'remitt I',
            'peak II',
            'remitt II',
            'peak III',
        ],
    },
}

# Gene panel used further down to subset the variability ranking.
genes = [
    'Hif1a', 'Hk2', 'Pfkl', 'Pdk1', 'Pkm',
    'Ldha', 'Ldhb', 'Slc16a1', 'Slc16a3', 'Serpina3n',
    'Ppargc1a', 'Mfn1', 'Mfn2', 'Opa1', 'Sirt2',
]


# Utilities

In [11]:
import numpy as np
import pandas as pd
import scipy.sparse as sp

def _to_dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)

def get_matrix(adata, layer=None):
    """
    Fetch an expression matrix from ``adata``.

    Returns ``adata.X`` when ``layer`` is None, otherwise ``adata.layers[layer]``.
    Raises ValueError when the requested layer does not exist.
    """
    if layer is None:
        return adata.X
    if layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found.")
    return adata.layers[layer]

def compute_logcpm(adata, layer=None):
    """
    Return a dense matrix of log1p(counts-per-million) for the chosen matrix.

    Each row is scaled by its own total (rows with a zero total are divided by 1
    and therefore stay all-zero), multiplied by 1e6, then passed through ``np.log1p``.
    """
    counts = _to_dense(get_matrix(adata, layer))
    row_totals = counts.sum(axis=1, keepdims=True)
    # avoid division by zero for empty rows
    row_totals[row_totals == 0] = 1
    cpm = (counts / row_totals) * 1e6
    return np.log1p(cpm)

# Functions to compute expression variability across lesion proximity

In [16]:
import numpy as np
import pandas as pd
import scipy.sparse as sp

def _to_dense(X):
    return X.toarray() if sp.issparse(X) else np.asarray(X)

def _get_matrix(adata, layer=None):
    if layer is None:
        return adata.X
    if layer in adata.layers:
        return adata.layers[layer]
    raise ValueError(f"Layer '{layer}' not found. Available: {list(adata.layers.keys())}")

def _compute_logcpm(adata, layer=None):
    """
    Dense log1p(CPM) of the selected matrix.

    Rows are divided by their own sum (a zero sum is replaced by 1 so empty rows
    stay zero), scaled to one million and transformed with ``np.log1p``.
    """
    dense_counts = _to_dense(_get_matrix(adata, layer))
    library_size = dense_counts.sum(axis=1, keepdims=True)
    library_size[library_size == 0] = 1
    scaled = (dense_counts / library_size) * 1e6
    return np.log1p(scaled)

def compute_dotplot_stats_by_proximity(
    adata,
    genes=None,                          # None = all genes
    celltype_col="celltype_merged",
    proximity_col="lesion_proximity",
    counts_layer="raw",
    selected_celltypes=None,
    selected_proximities=None,
    min_cells_per_group=0
):
    """
    Summarise expression per (cell type x proximity bin x gene).

    Parameters
    ----------
    adata : AnnData-like
        Must expose ``obs`` (DataFrame), ``var_names``, ``layers`` and ``X``.
    genes : list of str or None
        Genes to report. Names missing from ``adata.var_names`` are ignored and
        duplicates are collapsed; None means all genes.
    celltype_col, proximity_col : str
        ``adata.obs`` columns defining the groups. Labels are compared as strings.
    counts_layer : str or None
        Layer holding the counts; None uses ``adata.X``.
    selected_celltypes, selected_proximities : list of str or None
        Optional label whitelists; they also fix the category order of the output.
    min_cells_per_group : int
        Groups with fewer *cells* than this are dropped (0 disables the filter).

    Returns
    -------
    pandas.DataFrame
        Tidy table with columns
        ``[<celltype_col>, <proximity_col>, "gene", "mean_logcpm", "pct_expressing"]``.
        ``mean_logcpm`` is the group mean of log1p(counts / row total * 1e6), with row
        totals taken over all genes; ``pct_expressing`` is the percentage of cells in
        the group with a non-zero value.

    Raises
    ------
    ValueError
        If no cells pass the filters, none of the requested genes exist, the layer is
        missing, or every group is removed by ``min_cells_per_group``.
    """
    # Rows densified at once. Statistics are accumulated chunk by chunk so the full
    # cells x genes matrix never has to exist as a dense (or long-format) table.
    chunk_rows = 4096

    # ---- subset cells
    celltype_labels = adata.obs[celltype_col].astype(str)
    proximity_labels = adata.obs[proximity_col].astype(str)
    keep = np.ones(len(celltype_labels), dtype=bool)
    if selected_celltypes is not None:
        keep &= celltype_labels.isin(selected_celltypes).to_numpy()
    if selected_proximities is not None:
        keep &= proximity_labels.isin(selected_proximities).to_numpy()
    if not keep.any():
        raise ValueError("No cells left after filtering cell types / proximity.")

    # ---- choose genes
    var_names = pd.Index(adata.var_names)
    if genes is None:
        genes = var_names.tolist()
        gene_idx = None                  # keep every column
    else:
        # dict.fromkeys de-duplicates while preserving the requested order
        genes = list(dict.fromkeys(g for g in genes if g in var_names))
        if not genes:
            raise ValueError("None of the requested genes are in adata.var_names.")
        gene_idx = var_names.get_indexer(genes)

    # ---- resolve the count matrix
    if counts_layer is None:
        counts = adata.X
    elif counts_layer in adata.layers:
        counts = adata.layers[counts_layer]
    else:
        raise ValueError(
            f"Layer '{counts_layer}' not found. Available: {list(adata.layers.keys())}"
        )

    # ---- group the kept cells; the frame index stores each cell's row position
    cells = pd.DataFrame(
        {
            celltype_col: celltype_labels.to_numpy()[keep],
            proximity_col: proximity_labels.to_numpy()[keep],
        },
        index=np.flatnonzero(keep),
    )

    frames = []
    for (celltype, proximity), members in cells.groupby([celltype_col, proximity_col], sort=True):
        n_cells = len(members)
        # drop tiny groups — the threshold is applied to cells, not to cell x gene rows
        if min_cells_per_group and n_cells < min_cells_per_group:
            continue

        rows = members.index.to_numpy()
        logcpm_sum = np.zeros(len(genes), dtype=float)
        n_expressing = np.zeros(len(genes), dtype=float)
        for start in range(0, n_cells, chunk_rows):
            block = counts[rows[start:start + chunk_rows]]
            # library size uses all genes, before any gene subsetting
            lib = np.asarray(block.sum(axis=1), dtype=float).reshape(-1, 1)
            lib[lib == 0] = 1
            if gene_idx is not None:
                block = block[:, gene_idx]
            block = block.toarray() if sp.issparse(block) else np.asarray(block)
            logcpm = np.log1p((block / lib) * 1e6)
            logcpm_sum += logcpm.sum(axis=0)
            n_expressing += (logcpm > 0).sum(axis=0)

        frames.append(pd.DataFrame({
            celltype_col: celltype,
            proximity_col: proximity,
            "gene": genes,
            "mean_logcpm": logcpm_sum / n_cells,
            "pct_expressing": n_expressing / n_cells * 100.0,
        }))

    if not frames:
        raise ValueError("All (celltype×proximity) groups dropped by min_cells_per_group.")
    stats = pd.concat(frames, ignore_index=True)

    # ---- lock ordering
    stats["gene"] = pd.Categorical(stats["gene"], categories=genes, ordered=True)
    if selected_proximities is not None:
        stats[proximity_col] = pd.Categorical(
            stats[proximity_col],
            categories=list(selected_proximities),
            ordered=True
        )
    if selected_celltypes is not None:
        stats[celltype_col] = pd.Categorical(
            stats[celltype_col],
            categories=list(selected_celltypes),
            ordered=True
        )

    return stats.sort_values([celltype_col, proximity_col, "gene"]).reset_index(drop=True)

In [17]:
def variable_genes_over_proximity(
    stats,
    celltype,
    proximity_col="lesion_proximity",
    how=("logcpm","pct"),   # choose one or both; returns both by default
    celltype_col="celltype_merged",
):
    """
    Rank genes by variability across proximity bins for a single cell type.

    Parameters
    ----------
    stats : pandas.DataFrame
        Tidy table with columns
        ``[<celltype_col>, <proximity_col>, "gene", "mean_logcpm", "pct_expressing"]``.
    celltype : str
        Cell type to analyse (compared as a string against ``stats[celltype_col]``).
    proximity_col : str
        Column whose values become the columns of the pivoted matrices.
    how : str or iterable of str
        Which rankings to compute: "logcpm", "pct" or both.
    celltype_col : str
        Name of the cell type column in ``stats``.

    Returns
    -------
    dict
        ``"mat_logcpm"`` and ``"mat_pct"``: genes x proximity matrices (always present);
        ``"ranks"``: dict with a "logcpm" and/or "pct" DataFrame holding the per-gene
        variance ("var") and max-min "range" across proximity, sorted by variance
        in descending order.

    Raises
    ------
    ValueError
        If ``stats`` has no rows for ``celltype``.
    """
    # A bare string would otherwise be tested by substring ("pct" in "logcpm_pct").
    if isinstance(how, str):
        how = (how,)

    df = stats.loc[stats[celltype_col].astype(str) == str(celltype)]
    if df.empty:
        raise ValueError(f"No rows for celltype='{celltype}' in stats.")

    # pivot to genes × proximity
    value_columns = {"logcpm": "mean_logcpm", "pct": "pct_expressing"}
    mats = {
        key: df.pivot(index="gene", columns=proximity_col, values=column)
        for key, column in value_columns.items()
    }

    ranks = {}
    for key, mat in mats.items():
        if key not in how:
            continue
        ranks[key] = (
            pd.DataFrame({
                "var": mat.var(axis=1, skipna=True),
                "range": mat.max(axis=1) - mat.min(axis=1),
            })
            .sort_values("var", ascending=False)
        )
    return {"mat_logcpm": mats["logcpm"], "mat_pct": mats["pct"], "ranks": ranks}

In [14]:
import matplotlib.pyplot as plt
import seaborn as sns

# define your proximity column + order
proximity_col = "lesion_distance_bin"
proximity_order = ["0–10µm","10–25µm","25–50µm","50–100µm","100–200µm","200–500µm",">500µm"]  # change to your labels/order

# Use the 'raw' layer when it exists; otherwise fall back to adata.X (layer=None).
# Earlier cells move the raw counts into adata.X and delete the layer, so a hard-coded
# "raw" fails when the notebook is run top to bottom.
# NOTE(review): the fallback assumes adata.X holds raw counts — confirm.
counts_layer = "raw" if "raw" in adata.layers else None

# compute stats for the cell types and genes you care about
stats_prox = compute_dotplot_stats_by_proximity(
    adata,
    genes=None,
    celltype_col="celltype_merged",
    proximity_col=proximity_col,
    counts_layer=counts_layer,
    selected_celltypes=["Astrocyte", "Oligodendrocyte"],  # or any set
    selected_proximities=proximity_order,
    min_cells_per_group=20
)

# rank most variable genes over proximity for oligodendrocytes
res = variable_genes_over_proximity(
    stats_prox, celltype="Oligodendrocyte", proximity_col=proximity_col, how=("logcpm","pct")
)
top_log = res["ranks"]["logcpm"].head(15)
top_pct = res["ranks"]["pct"].head(15)
print("Top by variance (mean logCPM):\n", top_log, "\n")
print("Top by variance (% expressing):\n", top_pct)

# quick plot trajectories for the top genes (mean logCPM); tab10 provides 10 distinct colours
n_trajectories = 10
trajectory_genes = top_log.head(n_trajectories).index.tolist()
palette = sns.color_palette("tab10", n_colors=len(trajectory_genes))

# Align columns to proximity_order with reindex rather than .loc: a bin that was dropped
# (e.g. by min_cells_per_group) then shows up as a gap instead of raising KeyError.
mat = res["mat_logcpm"].loc[trajectory_genes]
mat = mat.set_axis(mat.columns.astype(str), axis=1).reindex(columns=proximity_order)

fig, ax = plt.subplots(figsize=(7,4))
for g, col in zip(trajectory_genes, palette):
    ax.plot(proximity_order, mat.loc[g], marker="o", label=fr"$\it{{{g}}}$", color=col)
ax.set_ylabel("mean logCPM")
ax.set_title("Oligodendrocyte — top variable genes across lesion proximity")
ax.legend(bbox_to_anchor=(1.02,1), loc="upper left")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


KeyboardInterrupt



# Plot top genes

In [18]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this
# notebook (the stored output of this cell is a NameError); define or import it first.
model = "PLP"
celltype = "Oligodendrocyte"
cm, var = variable_genes_per_model(
    adata, model_name=model, MODELS=MODELS,
    celltype=celltype
)

top_genes = var.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8,5))
for g, col in zip(top_genes, palette):
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set_ylabel("mean logCPM")
# The title is built from `celltype` so it cannot drift from the data being plotted
# (it previously said "Astrocyte" while Oligodendrocyte was computed).
ax.set_title(f"{model} — {celltype}: Top variable genes")
ax.legend(bbox_to_anchor=(1.05,1), loc="upper left")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

NameError: name 'variable_genes_per_model' is not defined

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this notebook.
model = "PLP"
cm, var = variable_genes_per_model(adata, model_name=model, MODELS=MODELS, celltype="Astrocyte")

# first five rows of the ranking, one tab10 colour per gene
top_genes = var.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8, 5))
for g, col in zip(top_genes, palette):
    # one line per gene across the columns of cm; label rendered in italics
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set(ylabel="mean logCPM", title=f"{model} — Astrocyte: Top variable genes")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.xticks(rotation=45)
fig.tight_layout()
plt.show()

In [None]:
# Restrict the variability ranking to the genes listed in `genes`.
var_sub = var.loc[var.index.isin(genes)]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this notebook.
model = "PLP"
celltype = "Astrocyte"
cm, var = variable_genes_per_model(
    adata, model_name=model, MODELS=MODELS,
    celltype=celltype
)

# Recompute the gene-panel subset from the ranking obtained in this cell, so the plot
# does not depend on a `var_sub` left over from whichever cell ran last.
var_sub = var[var.index.isin(genes)]
top_genes = var_sub.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8,5))
for g, col in zip(top_genes, palette):
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set_ylabel("mean logCPM")
ax.set_title(f"{model} — {celltype}: Top variable genes")
ax.legend(bbox_to_anchor=(1.05,1), loc="upper left")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this notebook.
model = "MOG"
celltype = "Astrocyte"
cm, var = variable_genes_per_model(
    adata, model_name=model, MODELS=MODELS,
    celltype=celltype
)

# Rank the gene panel with THIS model's variability table. Reusing a `var_sub` computed
# in an earlier cell would select genes from a different model's ranking while the title
# claims they are this model's top variable genes.
# NOTE(review): if plotting the same genes as the PLP figure was intended, pass that
# gene list explicitly instead.
var_sub = var[var.index.isin(genes)]
top_genes = var_sub.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8,5))
for g, col in zip(top_genes, palette):
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set_ylabel("mean logCPM")
ax.set_title(f"{model} — {celltype}: Top variable genes")
ax.legend(bbox_to_anchor=(1.05,1), loc="upper left")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this notebook.
model = "PLP"
cm, var = variable_genes_per_model(adata, model_name=model, MODELS=MODELS, celltype="Astrocyte")

# first five rows of the ranking, one tab10 colour per gene
top_genes = var.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8, 5))
for g, col in zip(top_genes, palette):
    # one line per gene across the columns of cm; label rendered in italics
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set(ylabel="mean logCPM", title=f"{model} — Astrocyte: Top variable genes")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.xticks(rotation=45)
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `variable_genes_per_model` is not defined in the visible part of this notebook.
model = "PLP"
cm, var = variable_genes_per_model(adata, model_name=model, MODELS=MODELS, celltype="Astrocyte")

# first five rows of the ranking, one tab10 colour per gene
top_genes = var.head(5).index
palette = sns.color_palette("tab10", n_colors=len(top_genes))

fig, ax = plt.subplots(figsize=(8, 5))
for g, col in zip(top_genes, palette):
    # one line per gene across the columns of cm; label rendered in italics
    ax.plot(cm.columns, cm.loc[g], marker="o", color=col, label=fr"$\it{{{g}}}$")

ax.set(ylabel="mean logCPM", title=f"{model} — Astrocyte: Top variable genes")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.xticks(rotation=45)
fig.tight_layout()
plt.show()

In [9]:
# Inspect the columns of `course_means`.
# NOTE(review): `course_means` is not defined in the visible part of this notebook, and
# the stored output below is a full table rather than a column index — re-run to refresh.
course_means.columns

course,MOG CFA,PLP CFA,chronic long,chronic peak,early onset,monophasic,non symptomatic,onset I,onset II,peak I,peak II,peak III,remitt I,remitt II
2610035D17Rik,0.760219,0.567049,0.684663,0.702827,0.600435,0.491483,0.641677,0.303612,0.386249,0.503658,0.420378,0.382122,0.484963,0.503577
9630013A20Rik,0.040635,0.034014,0.064848,0.105693,0.032976,0.047670,0.027849,0.026098,0.049895,0.086409,0.072568,0.058821,0.078464,0.009566
A1cf,0.001182,0.004700,0.004409,0.007733,0.003580,0.001271,0.000000,0.004661,0.003151,0.003579,0.002468,0.000658,0.004619,0.002409
A2m,5.090747,3.781237,5.922780,7.305157,5.906674,5.468080,4.578372,4.774786,4.965225,6.397564,5.945568,7.025800,5.428105,4.825909
Aatf,0.765351,0.711719,0.837476,0.785777,0.729061,0.837987,0.645680,0.717868,0.790823,0.968513,0.869408,0.959560,0.804610,0.856146
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Zswim9,0.242324,0.259676,0.264105,0.247415,0.219694,0.250217,0.235258,0.186908,0.249990,0.231577,0.241121,0.266640,0.257441,0.274791
Zup1,0.520665,0.338158,0.507997,0.442982,0.425951,0.342797,0.479227,0.235218,0.336152,0.374601,0.251204,0.284319,0.313890,0.355295
Zyx,0.724900,0.723223,0.928527,1.042805,0.882195,0.884206,0.577124,0.703261,0.993592,1.086883,0.734526,1.178293,0.839119,0.772139
Zzef1,1.127073,1.026866,1.124169,1.306193,1.169622,0.995753,1.033505,0.760119,1.002104,1.230116,0.939861,1.169605,1.056981,0.979196


In [8]:
# Display the `variability` table (the stored output shows per-gene 'var' and 'range'
# columns, apparently sorted by descending 'var').
# NOTE(review): `variability` is not defined in the visible part of this notebook.
variability

Unnamed: 0,var,range
Cd74,4.811892e+00,6.307446
H2-D1,4.285040e+00,5.773047
Serping1,3.851166e+00,6.077000
Serpina3n,3.560114e+00,6.332653
H2-K1,3.211678e+00,5.520392
...,...,...
Myh7,8.202805e-07,0.003130
Tbpl2,7.894877e-07,0.002742
Gsx2,6.465283e-07,0.002480
Hand1,5.580636e-07,0.001907
