In [4]:
import numpy as np
import pandas as pd
import scanpy as sc

In [None]:
# Load the integrated, processed AnnData object; later cells read its 'raw' layer.
h5ad_path = '../data/RREAE_5k_raw_integration_processed.h5ad'
adata = sc.read_h5ad(h5ad_path)

In [None]:
# Sanity check on the 'raw' layer: its maximum should look like a count
# (integer-valued) rather than a log-normalized value.
adata.layers["raw"].max()

In [None]:
def gene_counts_per_celltype_distance(
    adata,
    gene,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    layer=None,
    use_raw=True,
    observed=False,
):
    """
    For a given gene, compute per (celltype, distance_bin):
      - total counts of the gene
      - number of cells
      - mean counts per cell

    Parameters
    ----------
    adata : AnnData-like
        Needs ``.obs`` (DataFrame), ``.X``, ``.layers``, ``.var_names`` and,
        when ``use_raw=True``, ``.raw`` exposing ``.X`` / ``.var_names``.
    gene : str
        Gene name looked up in the selected ``var_names`` (first match wins).
    celltype_col, distance_col : str
        ``adata.obs`` columns to group by.
    layer : str or None
        Layer read when ``use_raw=False``; ``None`` means ``adata.X``.
        Ignored when ``use_raw=True``.
    use_raw : bool
        Read from ``adata.raw.X`` instead of ``adata.X`` / ``adata.layers``.
    observed : bool
        Forwarded to ``DataFrame.groupby``. Pinned to ``False`` by default
        (the pandas < 3 default) so results do not change with the pandas
        version: every combination of categorical levels is kept, including
        empty ones (``n_cells == 0``, mean ``NaN``). ``True`` keeps only
        combinations that actually occur.

    Returns
    -------
    pandas.DataFrame
        Columns ``celltype_col``, ``distance_col``, ``total_gene_counts``,
        ``n_cells``, ``mean_gene_counts_per_cell``, ``gene``.

    Raises
    ------
    ValueError
        If ``use_raw=True`` but ``adata.raw`` is None, or ``gene`` is absent.
    """
    from scipy import sparse
    import numpy as np

    # choose matrix and var_names
    if use_raw:
        if adata.raw is None:
            raise ValueError("use_raw=True but adata.raw is None")
        X = adata.raw.X
        var_names = np.asarray(adata.raw.var_names)
    elif layer is not None:
        X = adata.layers[layer]
        var_names = np.asarray(adata.var_names)
    else:
        X = adata.X
        var_names = np.asarray(adata.var_names)

    # locate gene index (single scan; first match if names are duplicated)
    hits = np.flatnonzero(var_names == gene)
    if hits.size == 0:
        raise ValueError(f"Gene {gene} not found in var_names")
    g_idx = hits[0]

    # extract gene counts per cell as a dense 1-D vector
    if sparse.issparse(X):
        if X.format not in ("csr", "csc"):
            # formats such as COO do not support column indexing
            X = X.tocsc()
        gene_counts = np.asarray(X[:, g_idx].toarray()).ravel()
    else:
        gene_counts = np.asarray(X[:, g_idx]).ravel()

    # assemble df; gene_counts is positional, aligned with obs row order
    df = adata.obs[[celltype_col, distance_col]].copy()
    df["gene_counts"] = gene_counts

    grouped = (df.groupby([celltype_col, distance_col], observed=observed)
                 .agg(total_gene_counts=("gene_counts", "sum"),
                      n_cells=("gene_counts", "size"))
                 .reset_index())

    # empty groups (possible with observed=False) get NaN instead of 0/0
    n_cells = grouped["n_cells"]
    grouped["mean_gene_counts_per_cell"] = (
        grouped["total_gene_counts"] / n_cells.where(n_cells > 0)
    )

    grouped["gene"] = gene
    return grouped

In [None]:
# Slc16a3: raw-layer counts aggregated per (cell type, lesion-distance bin).
df_slc16a3 = gene_counts_per_celltype_distance(
    adata,
    "Slc16a3",
    layer="raw",
    use_raw=False,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
)

# Show only the astrocyte rows.
df_slc16a3[df_slc16a3["celltype_merged"] == "Astrocyte"]

In [None]:
# Ldha counts per (cell type, lesion-distance bin). Stored under a gene-specific
# name so `df_slc16a3` keeps holding the Slc16a3 result.
df_ldha = gene_counts_per_celltype_distance(
    adata,
    gene="Ldha",
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    use_raw=False,
    layer="raw",
)

# Just astrocytes
df_ldha.query("celltype_merged == 'Astrocyte'")

In [None]:
# Gene panel used by the dotplot cells below.
genes = [
    'Hif1a', 'Hk2', 'Pfkl', 'Pdk1', 'Pkm',
    'Ldha', 'Ldhb', 'Slc16a1', 'Slc16a3',
    'Serpina3n', 'Ppargc1a',
    'Mfn1', 'Mfn2', 'Opa1',
    'Sirt2',
]


In [None]:
import re, numpy as np, pandas as pd, seaborn as sns, matplotlib.pyplot as plt
from scipy import sparse

# --- helper: compute per-gene counts by (celltype, distance bin)
def gene_counts_per_celltype_distance(adata, gene, celltype_col="celltype_merged",
                                      distance_col="lesion_distance_bin",
                                      use_raw=False, layer="raw", observed=False):
    """Aggregate one gene's counts per (cell type, distance bin).

    Note: this redefines the function of the same name defined earlier in the
    notebook. Here a missing gene returns an empty frame (with a warning)
    instead of raising, and rows whose cell type or bin is missing are kept
    as their own group (``dropna=False``).

    Parameters
    ----------
    adata : AnnData-like
        Needs ``.obs``, ``.X``, ``.layers``, ``.var_names`` and, when
        ``use_raw=True``, ``.raw`` exposing ``.X`` / ``.var_names``.
    gene : str
        Gene name looked up in the selected ``var_names`` (first match wins).
    celltype_col, distance_col : str
        ``adata.obs`` columns to group by.
    use_raw : bool
        Read from ``adata.raw.X``; otherwise from ``adata.layers[layer]``
        (or ``adata.X`` when ``layer`` is None).
    layer : str or None
        Layer name used when ``use_raw=False``.
    observed : bool
        Forwarded to ``groupby``; pinned to ``False`` (pandas < 3 default) so
        results do not depend on the pandas version.

    Returns
    -------
    pandas.DataFrame
        Columns ``celltype_col``, ``distance_col``, ``total_gene_counts``,
        ``n_cells``, ``mean_gene_counts_per_cell``, ``gene``.
    """
    import warnings

    if use_raw:
        if adata.raw is None:
            raise ValueError("use_raw=True but adata.raw is None")
        X = adata.raw.X
        var_names = np.asarray(adata.raw.var_names)
    else:
        X = adata.layers[layer] if layer is not None else adata.X
        var_names = np.asarray(adata.var_names)

    columns = [celltype_col, distance_col, "total_gene_counts", "n_cells",
               "mean_gene_counts_per_cell", "gene"]
    hits = np.flatnonzero(var_names == gene)
    if hits.size == 0:
        # best-effort: callers concatenate several genes, so don't raise,
        # but make the omission visible
        warnings.warn(f"Gene {gene!r} not found in var_names; returning an empty frame.",
                      stacklevel=2)
        return pd.DataFrame(columns=columns)
    g_idx = hits[0]

    if sparse.issparse(X):
        # CSR and CSC support column indexing directly (CSC is the fast one);
        # only convert formats that are not subscriptable, such as COO.
        if X.format not in ("csr", "csc"):
            X = X.tocsc()
        gene_counts = np.asarray(X[:, g_idx].toarray()).ravel()
    else:
        gene_counts = np.asarray(X[:, g_idx]).ravel()

    df = adata.obs[[celltype_col, distance_col]].copy()
    df["gene_counts"] = gene_counts

    out = (df.groupby([celltype_col, distance_col], dropna=False, observed=observed)
             .agg(total_gene_counts=("gene_counts", "sum"),
                  # "size" counts rows; counting the non-null celltype_col
                  # would report 0 cells for the NaN-celltype group
                  n_cells=("gene_counts", "size"))
             .reset_index())
    n_cells = out["n_cells"]
    out["mean_gene_counts_per_cell"] = out["total_gene_counts"] / n_cells.where(n_cells > 0)
    out["gene"] = gene
    return out

# --- helper: natural sort of distance bins like "0-25 µm", "25-50 µm", ...
def _bin_key(s):
    if s is None or (isinstance(s, float) and np.isnan(s)): return (float("inf"),)
    m = re.search(r"(\d+)", str(s))
    return (int(m.group(1)) if m else float("inf"),)

# --- main: build tidy df for selected genes & celltypes and plot
def plot_distance_dotplot(adata, genes_subset, celltypes_subset=None,
                          celltype_col="celltype_merged", distance_col="lesion_distance_bin",
                          use_raw=False, layer="raw", sizes=(20, 300), cmap="viridis",
                          col_wrap=None):
    """Dot plot of per-bin gene counts: one facet per gene.

    x = distance bin (numeric order via ``_bin_key``), y = cell type,
    color = mean counts per cell, marker size = total counts.

    Parameters
    ----------
    adata : AnnData-like
        Passed to ``gene_counts_per_celltype_distance``.
    genes_subset : iterable of str
        Genes to plot (one facet column each).
    celltypes_subset : list of str or None
        Cell types to keep, in y-axis order; ``None`` keeps all.
    celltype_col, distance_col, use_raw, layer
        Forwarded to ``gene_counts_per_celltype_distance``.
    sizes : (float, float)
        Min/max marker area for ``sns.relplot``.
    cmap : str
        Colormap for the hue.
    col_wrap : int or None
        Wrap gene facets after this many columns (``None``: single row).

    Returns
    -------
    pandas.DataFrame
        The tidy data used for the plot.

    Raises
    ------
    ValueError
        If none of the genes yield data, or the cell-type filter removes all rows.
    """
    # gather per-gene summaries; genes missing from var_names may come back empty
    frames = [
        gene_counts_per_celltype_distance(
            adata, gene, celltype_col=celltype_col, distance_col=distance_col,
            use_raw=use_raw, layer=layer
        )
        for gene in genes_subset
    ]
    frames = [frame for frame in frames if not frame.empty]
    if not frames:
        raise ValueError(f"No data for any of the requested genes: {list(genes_subset)}")
    df = pd.concat(frames, ignore_index=True)

    # subset cell types (optional)
    if celltypes_subset is not None:
        celltypes_subset = list(celltypes_subset)
        df = df[df[celltype_col].isin(celltypes_subset)].copy()
        if df.empty:
            raise ValueError(
                f"No rows for cell types {celltypes_subset} in column {celltype_col!r}"
            )

    # x axis: distance bins in numeric order
    bins = sorted(df[distance_col].dropna().unique(), key=_bin_key)
    df[distance_col] = pd.Categorical(df[distance_col], categories=bins, ordered=True)
    # y axis: the given order when subsetting; otherwise the column's existing
    # category order (alphabetical for plain strings)
    if celltypes_subset is not None:
        df[celltype_col] = pd.Categorical(df[celltype_col], categories=celltypes_subset, ordered=True)
    else:
        df[celltype_col] = pd.Categorical(df[celltype_col], ordered=True)

    # plot (one facet column per gene)
    grid = sns.relplot(
        data=df,
        x=distance_col, y=celltype_col,
        hue="mean_gene_counts_per_cell",
        size="total_gene_counts",
        col="gene", col_wrap=col_wrap, kind="scatter",
        palette=cmap, sizes=sizes, alpha=0.85, edgecolor="none"
    )
    grid.set_xticklabels(rotation=45, ha="right")
    grid.set_ylabels("Cell type")
    grid.set_xlabels("Lesion distance bin (µm)")
    grid.figure.suptitle("Per-cell mean counts (color) & total counts (size) by distance", y=1.02)
    plt.show()
    return df  # return the tidy data used for the plot

# --- usage example: one gene, astrocytes only ---
genes_subset = ["Slc16a1"]
celltypes_subset = ["Astrocyte"]  # or None for all cell types

# Use use_raw=True instead of layer="raw" if adata.raw holds the counts;
# `sizes` sets the min/max marker area.
df_dot = plot_distance_dotplot(
    adata,
    genes_subset=genes_subset,
    celltypes_subset=celltypes_subset,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    use_raw=False,
    layer="raw",
    sizes=(30, 450),
    cmap="viridis",
)
# df_dot: total_gene_counts, n_cells, mean_gene_counts_per_cell per (celltype, bin, gene)

In [None]:
# Which merged cell-type labels exist (used to pick celltypes_subset below).
adata.obs["celltype_merged"].unique()

In [None]:
# Ldha across a selection of cell types.
genes_subset = ["Ldha"]
celltypes_subset = [
    "Astrocyte", "Oligodendrocyte", "Microglia",
    "T cell", "B cell", "Dendritic cell",
]  # or None for all

df_dot = plot_distance_dotplot(
    adata,
    genes_subset=genes_subset,
    celltypes_subset=celltypes_subset,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    use_raw=False,
    layer="raw",
    sizes=(30, 450),
    cmap="viridis",
)

In [None]:
# Slc16a1 across the same selection of cell types.
genes_subset = ["Slc16a1"]
celltypes_subset = [
    "Astrocyte", "Oligodendrocyte", "Microglia",
    "T cell", "B cell", "Dendritic cell",
]  # or None for all

df_dot = plot_distance_dotplot(
    adata,
    genes_subset=genes_subset,
    celltypes_subset=celltypes_subset,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    use_raw=False,
    layer="raw",
    sizes=(30, 450),
    cmap="viridis",
)

In [None]:
# Serpina3n across the same selection of cell types.
genes_subset = ["Serpina3n"]
celltypes_subset = [
    "Astrocyte", "Oligodendrocyte", "Microglia",
    "T cell", "B cell", "Dendritic cell",
]  # or None for all

df_dot = plot_distance_dotplot(
    adata,
    genes_subset=genes_subset,
    celltypes_subset=celltypes_subset,
    celltype_col="celltype_merged",
    distance_col="lesion_distance_bin",
    use_raw=False,
    layer="raw",
    sizes=(30, 450),
    cmap="viridis",
)

In [None]:
# Same gene panel as defined earlier in the notebook (re-stated here so the
# dotplot cells below can run on their own).
genes = [
    'Hif1a', 'Hk2', 'Pfkl', 'Pdk1', 'Pkm',
    'Ldha', 'Ldhb', 'Slc16a1', 'Slc16a3',
    'Serpina3n', 'Ppargc1a',
    'Mfn1', 'Mfn2', 'Opa1',
    'Sirt2',
]


In [None]:
genes = ['Hif1a','Hk2','Pfkl','Pdk1','Pkm','Ldha','Ldhb','Slc16a1','Slc16a3','Serpina3n','Ppargc1a',"Mfn1","Mfn2","Opa1",'Sirt2']

# One dotplot per merged cell type: `genes` across disease course, each gene
# scaled to [0, 1] within that cell type (standard_scale="var").
# sc.pl.dotplot draws the figure itself; its return value is not used here.
# dropna(): a NaN label would select zero cells and make dotplot fail.
for cell in adata.obs.celltype_merged.dropna().unique():
    print(cell)
    sc.pl.dotplot(
        adata[adata.obs.celltype_merged == cell],
        var_names=genes,
        groupby="course",
        standard_scale="var",
        #dot_max=0.5,
        #dot_min=0.05,
        color_map="Reds",
        dendrogram=False,
        figsize=(5, 3),
        #categories_order=['Ctrl','Early','Peak','Late'],
        title=cell
    )


In [None]:
import scanpy as sc
import matplotlib as mpl
mpl.rcParams['svg.fonttype'] = 'none'  # keep text editable

# Gene panel across the RR disease course.
# return_fig=True makes `dp` a DotPlot object, so the dp.savefig line below can
# work; without it sc.pl.dotplot returns None (or a dict of axes with
# show=False). dp.show() then draws the figure.
# NOTE(review): a later cell relabels Chronic 'remitt II' cells as RR; they are
# only included here if that cell already ran — confirm the intended order.
dp = sc.pl.dotplot(
    adata[adata.obs.model == 'RR'],
    var_names=genes,
    groupby="course",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    categories_order=['PLP CFA','onset I','onset II','peak I','monophasic','remitt I','peak II','remitt II','peak III'],
    figsize=(8,3),
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_RR.svg")
dp.show()

In [None]:
# Move Chronic-model cells at course 'remitt II' to the RR model.
mask = adata.obs["model"].eq("Chronic") & adata.obs["course"].eq("remitt II")
adata.obs.loc[mask, "model"] = "RR"


In [None]:
# Gene panel across the Chronic disease course.
# return_fig=True makes `dp` a DotPlot object so the dp.savefig line below can
# work (otherwise sc.pl.dotplot returns None); dp.show() draws the figure.
# categories_order has no 'remitt II': presumably relies on the relabel cell
# above having moved those cells to RR — confirm.
dp = sc.pl.dotplot(
    adata[adata.obs.model == 'Chronic'],
    var_names=genes,
    groupby="course",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    categories_order=['MOG CFA',
                      'non symptomatic',
                      'early onset',
                      'chronic peak',
                      'chronic long'],
    figsize=(6, 3),
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_Chronic.svg", bbox_inches="tight")
dp.show()

In [None]:
# Same Chronic plot as the previous cell, at a larger figure size.
# return_fig=True makes `dp` a DotPlot object so dp.savefig can work;
# dp.show() draws it. The save path differs from the previous cell so the
# two exports do not overwrite each other.
dp = sc.pl.dotplot(
    adata[adata.obs.model == 'Chronic'],
    var_names=genes,
    groupby="course",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    categories_order=['MOG CFA',
                      'non symptomatic',
                      'early onset',
                      'chronic peak',
                      'chronic long'],
    figsize=(10, 4),
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_Chronic_large.svg", bbox_inches="tight")
dp.show()

In [None]:
# Display order for sub_type_III cell types (passed as categories_order to the
# all-cell-type dotplot).
ordered_cell_types = [
    # --- neurons and oligodendrocyte lineage
    'Neuron',                    # excitatory/inhibitory signal transmission; main targets of neurodegeneration
    'OPC',                       # oligodendrocyte precursor cells; proliferative, remyelination potential
    'OPC (cycling)',             # actively dividing OPCs during repair or inflammation
    'Oligodendrocyte',           # myelinating glia; maintain axonal conduction, metabolically support neurons
    'DA-Oligodendrocyte',        # disease-associated oligodendrocytes; altered myelin/lipid metabolism
    # --- astroglia and ependyma
    'Astrocyte',                 # structural/metabolic support, blood-brain barrier, neurotransmitter cycling
    'DA-Astrocyte',              # reactive astrocytes with pro- or anti-inflammatory phenotypes
    'Ependymal',                 # line ventricles/central canal; regulate cerebrospinal fluid composition
    # --- microglia
    'Microglia (homeostatic)',   # CNS-resident immune cells in surveillance mode
    'Microglia (intermediate)',  # transitional activation states, partly inflammatory
    'Microglia (cycling)',       # proliferating microglia during inflammatory expansion
    'Foamy Microglia',           # lipid-laden microglia, often in demyelinated lesions
    # --- infiltrating myeloid / antigen-presenting cells
    'Monocyte',                  # peripheral immune cells infiltrating CNS, non-inflammatory phenotype
    'Monocyte (inflammatory)',   # infiltrating monocytes with pro-inflammatory transcriptional profile
    'APC/Myeloid',               # antigen-presenting myeloid cells; drive adaptive immune activation
    'Foamy Myeloid',             # lipid-rich infiltrating myeloid cells, linked to chronic lesions
    'Dendritic cell',            # professional antigen-presenting cells; activate T cells
    # --- lymphocytes
    'T cell',                    # adaptive immunity, immune surveillance
    'T cell (cycling)',          # activated, proliferating T cells during immune response
    'B cell',                    # adaptive immunity, antibody production, antigen presentation
    # --- vascular and stromal
    'Endothelial',               # vascular lining cells; blood-brain barrier integrity
    'Pericyte',                  # perivascular support; regulate blood flow and BBB permeability
    'Fibroblast',                # ECM production, scar formation, meningeal fibrosis
    'Fibroblast (cycling)',      # actively dividing fibroblasts in fibrosis/repair
    # --- mixed / unresolved
    'Mixed glia-vascular',       # hybrid or transitional population at glia-vascular interfaces
    'Glial-like',                # ambiguous glial phenotype, potentially progenitor or transitional
    'Metabolic-like',            # high metabolic activity, unclear lineage or specialized role
]

In [None]:
adata.X  # display the main matrix, e.g. to check whether it is sparse and what scale its values are on

In [None]:
# Gene panel across all sub_type_III cell types, rows ordered by
# ordered_cell_types (presumably every listed label must be a sub_type_III
# category — confirm if scanpy complains).
# standard_scale="var" already scales each gene to [0, 1], so vmax=1 is the top.
# return_fig=True makes `dp` a DotPlot object so dp.savefig can work;
# dp.show() draws it.
dp = sc.pl.dotplot(
    adata,
    var_names=genes,
    groupby="sub_type_III",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    figsize=(8, 8),
    #vmin=-5,
    vmax=1,
    categories_order = ordered_cell_types,
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_all_subtypes.svg", bbox_inches="tight")
dp.show()

In [None]:
# Merged cell-type labels as a plain list.
list(adata.obs["celltype_merged"].unique())

In [None]:
# Collapse the sub-types listed in merge_map into their parent label and store
# the result as celltype_merged.
merge_map = {
    "Fibroblast (cycling)": "Fibroblast",
}

celltype_col = "sub_type_III"
_source = adata.obs[celltype_col]
# Series.replace on a categorical that merges categories is deprecated in
# pandas 2.x, so replace on plain objects and rebuild the categorical.
_merged = _source.astype(object).replace(merge_map)
if isinstance(_source.dtype, pd.CategoricalDtype):
    # keep the original category order, drop merged-away labels, and append
    # any target label that was not already a category
    _categories = [c for c in _source.cat.categories if c not in merge_map]
    _categories += [t for t in dict.fromkeys(merge_map.values()) if t not in _categories]
    _merged = pd.Categorical(_merged, categories=_categories)
adata.obs["celltype_merged"] = _merged


In [None]:
# Gene panel for fibroblast sub-types only.
# na=False: a missing sub_type_III label must yield False, not NaN, or the
# boolean selection fails. return_fig=True makes `dp` a DotPlot object so
# dp.savefig can work; dp.show() draws it.
dp = sc.pl.dotplot(
    adata[adata.obs.sub_type_III.str.contains('Fibro', regex=False, na=False)],
    var_names=genes,
    groupby="sub_type_III",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    figsize=(8, 2),
    #vmin=-5,
    vmax=1,
    #categories_order = ordered_cell_types,
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_Fibroblast.svg", bbox_inches="tight")
dp.show()

In [None]:
# All sub_type_III labels as a plain list.
list(adata.obs["sub_type_III"].unique())

In [None]:
# Gene panel for OPC sub-types only.
# na=False: a missing sub_type_III label must yield False, not NaN, or the
# boolean selection fails. return_fig=True makes `dp` a DotPlot object so
# dp.savefig can work; dp.show() draws it.
dp = sc.pl.dotplot(
    adata[adata.obs.sub_type_III.str.contains('OPC', regex=False, na=False)],
    var_names=genes,
    groupby="sub_type_III",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    figsize=(8, 2),
    #vmin=-5,
    vmax=1,
    #categories_order = ordered_cell_types,
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_OPC.svg", bbox_inches="tight")
dp.show()

In [None]:
# Gene panel for astrocyte sub-types (labels containing 'Astr').
# na=False: a missing sub_type_III label must yield False, not NaN, or the
# boolean selection fails. return_fig=True makes `dp` a DotPlot object so
# dp.savefig can work; dp.show() draws it.
dp = sc.pl.dotplot(
    adata[adata.obs.sub_type_III.str.contains('Astr', regex=False, na=False)],
    var_names=genes,
    groupby="sub_type_III",
    standard_scale="var",
    dot_max=0.5, dot_min=0.05, color_map="Reds",
    dendrogram=False,
    figsize=(8, 2),
    #vmin=-5,
    vmax=1,
    #categories_order = ordered_cell_types,
    return_fig=True,
)
#dp.savefig("../../data/plots/bioenergetic_map_Astrocyte.svg", bbox_inches="tight")
dp.show()

In [None]:
# Wilcoxon differential expression: cycling vs. non-cycling fibroblasts,
# with "Fibroblast" as the baseline group.
# NOTE(review): this tests the 'raw' layer; scanpy's rank_genes_groups expects
# log-transformed data, so logfoldchanges computed from raw counts are likely
# not meaningful — confirm which data should be used.
fibro_mask = adata.obs["sub_type_III"].isin(["Fibroblast (cycling)", "Fibroblast"])
fibro = adata[fibro_mask].copy()

sc.tl.rank_genes_groups(
    fibro,
    groupby="sub_type_III",
    reference="Fibroblast",
    method="wilcoxon",
    layer="raw",
)



In [None]:
# DE table for cycling fibroblasts against the "Fibroblast" reference.
deg_two = sc.get.rank_genes_groups_df(
    fibro,
    group="Fibroblast (cycling)",
    key="rank_genes_groups",
)
deg_two.head()

In [None]:
# Restrict the DE table to the gene panel.
deg_two.loc[deg_two["names"].isin(genes)]