In [None]:
### Import Libraries.

import os
import anndata as ad
import pandas as pd
import numpy as np
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import sparse
from scipy.sparse import issparse
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

In [None]:
### Load Data.

# Working directory for the whole notebook: the input .h5ad below and every
# relative output path written later resolve against it.
DATA_DIR = "/folder/"
os.chdir(DATA_DIR)

adata = ad.read_h5ad("adata.h5ad")

In [None]:
### Subset AnnData.

# Keep non-enriched spinal-cord cells from the three Status groups compared downstream.
# All three conditions go into one boolean mask, so the AnnData is copied once instead of
# materialising two throwaway full copies along the way. Cell order is preserved.
subset_mask = (
    adata.obs["Status"].isin(["C9ALS", "sALS", "Control"])
    & adata.obs["Enrichment"].isin(["No"])
    & adata.obs["Region"].isin(["Spinal_Cord"])
)
adata_subset = adata[subset_mask.to_numpy()].copy()

assert adata_subset.n_obs > 0, "No cells left after subsetting on Status/Enrichment/Region."

In [None]:
### Filter genes expressed in ≥20% of cells in ≥1 Status group.

# A gene is kept if, in at least one Status group, this fraction of cells has X > 0.
MIN_FRAC_EXPRESSED = 0.2

status = adata_subset.obs['Status']
groups = status.unique()

# Boolean "detected" matrix (cells x genes). When X is sparse this stays sparse, so the
# full cells x genes matrix is never densified.
expr_bool = adata_subset.X > 0

keep_mask = np.zeros(adata_subset.n_vars, dtype = bool)
for group in groups:
    in_group = (status == group).to_numpy()
    n_cells = in_group.sum()
    if n_cells == 0:
        # e.g. a NaN Status value matches no cell; skip rather than divide 0 by 0.
        continue
    # Column sums over the group's rows = number of cells in the group expressing each gene.
    n_expressing = np.asarray(expr_bool[in_group].sum(axis = 0)).ravel()
    keep_mask |= (n_expressing / n_cells) >= MIN_FRAC_EXPRESSED

# Gene names in their original var order.
keep_genes = adata_subset.var_names[keep_mask].tolist()

adata_filtered = adata_subset[:, keep_genes].copy()
print(f"Original genes: {adata_subset.n_vars}")
print(f"After expression filter: {adata_filtered.n_vars}")

In [None]:
### Keep top 80% most expressed genes.

# Genes whose mean count is below this percentile are dropped; genes tied at the
# cutoff value are kept, so slightly more than 80% may survive.
percentile_cutoff = 20

# Prefer raw counts; fall back to X when no "counts_RNA" layer exists.
if "counts_RNA" in adata_filtered.layers:
    counts = adata_filtered.layers['counts_RNA']
else:
    counts = adata_filtered.X

# Sparse .mean(axis=0) returns a 1 x n_genes matrix, so convert it to an ndarray before flattening.
column_means = counts.mean(axis = 0)
mean_expr = (np.array(column_means) if issparse(counts) else column_means).flatten()

threshold = np.percentile(mean_expr, percentile_cutoff)
above_cutoff = mean_expr >= threshold
genes_to_keep = adata_filtered.var_names[above_cutoff]
adata_filtered_dynamic = adata_filtered[:, genes_to_keep].copy()

print(f"Genes before percentile filter: {adata_filtered.n_vars}")
print(f"After top 80% filter: {adata_filtered_dynamic.n_vars}")

In [None]:
### Plot: Fraction of cells expressing each gene.

# Count the cells detecting each gene directly on the boolean (possibly sparse) matrix;
# densifying the whole cells x genes matrix just to count nonzeros is unnecessary.
expr_bool = adata_filtered_dynamic.X > 0
frac_expr = np.asarray(expr_bool.sum(axis = 0)).ravel() / adata_filtered_dynamic.n_obs

fig, ax = plt.subplots(figsize = (6, 4))
ax.hist(frac_expr, bins = 30, color = 'skyblue', edgecolor = 'black')
ax.set_xlabel("Fraction of cells expressing gene")
ax.set_ylabel("Number of genes")
ax.set_title("Distribution of gene expression across cells")
plt.show()

In [None]:
### Plot: Mean expression per gene.

# Same count source as the percentile filter: "counts_RNA" layer if present, else X.
if "counts_RNA" in adata_filtered_dynamic.layers:
    counts = adata_filtered_dynamic.layers['counts_RNA']
else:
    counts = adata_filtered_dynamic.X

per_gene_mean = counts.mean(axis = 0)
mean_expr = (np.array(per_gene_mean) if issparse(counts) else per_gene_mean).flatten()

fig, ax = plt.subplots(figsize = (6, 4))
ax.hist(mean_expr, bins = 30, color = 'lightgreen', edgecolor = 'black')
ax.set_xlabel("Mean expression per gene")
ax.set_ylabel("Number of genes")
ax.set_title("Distribution of gene expression levels")
plt.show()

In [None]:
### Create pseudobulk matrix.

# DESeq2 needs raw counts: use the "counts_RNA" layer, falling back to X only if it is absent.
# The matrix is only read here, so no defensive copy is needed.
count_matrix = adata_filtered_dynamic.layers.get("counts_RNA", adata_filtered_dynamic.X)
metadata = adata_filtered_dynamic.obs.copy()
gene_info = adata_filtered_dynamic.var.copy()

# Use str sample IDs for BOTH the count rows and the metadata rows so they always align
# (a categorical or integer Sample_ID would otherwise not match a string-keyed index).
metadata["Sample_ID"] = metadata["Sample_ID"].astype(str)

# Sum cells into samples with a sparse (samples x cells) 0/1 indicator matrix instead of
# densifying the full cells x genes matrix. Sample IDs come out sorted, matching groupby order.
sample_codes, sample_ids = pd.factorize(metadata["Sample_ID"], sort = True)
n_cells = metadata.shape[0]
indicator = sparse.csr_matrix(
    (np.ones(n_cells), (sample_codes, np.arange(n_cells))),
    shape = (len(sample_ids), n_cells),
)
summed = indicator @ count_matrix
summed = summed.toarray() if issparse(summed) else np.asarray(summed)

# Fail early with a clear message if the counts are not integer-valued (e.g. X holds
# normalized values because the "counts_RNA" layer is missing).
if not np.allclose(summed, np.round(summed), rtol = 0, atol = 1e-6):
    raise ValueError(
        "Pseudobulk counts are not integers; DESeq2 requires raw counts "
        "(check that the 'counts_RNA' layer exists)."
    )

pseudobulk = pd.DataFrame(
    np.round(summed).astype(np.int64),
    index = pd.Index(sample_ids, name = "Sample_ID"),
    columns = gene_info.index,
)

# Sample-level covariates must be constant across a sample's cells, because
# drop_duplicates below keeps only the first cell's row per sample.
sample_level_cols = ["Dataset", "Sex", "Status"]
values_per_sample = metadata.groupby("Sample_ID")[sample_level_cols].nunique(dropna = False)
inconsistent = values_per_sample.index[(values_per_sample > 1).any(axis = 1)].tolist()
if inconsistent:
    raise ValueError(f"Samples with conflicting {sample_level_cols} across cells: {inconsistent}")

# One metadata row per sample, in the same order as the pseudobulk rows.
pseudobulk_metadata = (
    metadata.drop_duplicates(subset = ["Sample_ID"])
    .set_index("Sample_ID")
    .loc[pseudobulk.index]
    .copy()
)

print("✅ Pseudobulk shape:", pseudobulk.shape)
print("✅ Metadata shape:", pseudobulk_metadata.shape)

# Cast design factors to plain str (presumably so pydeseq2 treats them as categorical levels).
for col in sample_level_cols:
    pseudobulk_metadata[col] = pseudobulk_metadata[col].astype(str)

In [None]:
### Differential Expression (DESeq2)

# Contrast tested: CASE_STATUS vs REFERENCE_STATUS within the "Status" design factor.
# Set CASE_STATUS = "C9ALS" for the C9ALS comparison.
CASE_STATUS = "sALS"
REFERENCE_STATUS = "Control"

missing_levels = {CASE_STATUS, REFERENCE_STATUS} - set(pseudobulk_metadata["Status"])
if missing_levels:
    raise ValueError(f"Status level(s) {sorted(missing_levels)} not present in pseudobulk_metadata.")

# os.cpu_count() may return None when the CPU count cannot be determined.
n_cpus = min(20, os.cpu_count() or 1)
# One inference backend shared by model fitting and the statistical test.
inference = DefaultInference(n_cpus = n_cpus)

dds = DeseqDataSet(
    counts = pseudobulk,
    metadata = pseudobulk_metadata,
    design_factors = ["Dataset", "Sex", "Status"],
    refit_cooks = True,
    inference = inference,
    n_cpus = n_cpus,
)
dds.deseq2()

ds = DeseqStats(dds, contrast = ["Status", CASE_STATUS, REFERENCE_STATUS], inference = inference)
ds.summary()

In [None]:
# Move gene identifiers from the results index into a trailing "gene" column, then
# write a flat table (no index) to Excel.
de_results = (
    ds.results_df
    .assign(gene = lambda frame: frame.index)
    .reset_index(drop = True)
)

de_results.to_excel("DGE_Analysis.xlsx", index = False)
print(de_results.head())

In [None]:
### Matrix Plot.

marker_genes = ["Feature_A", "Feature_B", "Feature_C"]

# Order of Status groups along the sample axis; statuses not listed here are placed last.
status_order = ["C9ALS", "Control", "sALS", "fALS"]

# Rank samples WITHOUT mutating adata.obs: re-casting adata.obs["Status"] in place to a
# Categorical would permanently turn any status missing from status_order into NaN.
sample_status = adata.obs[["Sample_ID", "Status"]].drop_duplicates("Sample_ID")
status_rank = pd.Categorical(sample_status["Status"], categories = status_order, ordered = True).codes
# Categorical codes are -1 for values outside the categories; push those samples to the end.
status_rank = np.where(status_rank < 0, len(status_order), status_rank)

# Stable sort keeps the first-seen sample order within each Status group (reproducible).
sample_order = (
    sample_status.assign(Status_Rank = status_rank)
    .sort_values("Status_Rank", kind = "stable")["Sample_ID"]
    .tolist()
)

matrix = sc.pl.matrixplot(
    adata,
    var_names = marker_genes[::-1],
    groupby = "Sample_ID",
    categories_order = sample_order,
    layer = "log1p_normalized",
    cmap = "Spectral_r",
    standard_scale = "var",
    colorbar_title = "Expression",
    title = "Matrix Plot of Cell Type Marker Genes",
    swap_axes = True,
    return_fig = True,
    show = False
)

axes_by_name = matrix.get_axes()

# Shared cosmetic styling for every returned panel.
# plt.Axes is matplotlib.axes.Axes; the bare name `matplotlib` is not imported in this notebook.
for panel_ax in axes_by_name.values():
    if isinstance(panel_ax, plt.Axes):
        panel_ax.set_facecolor("white")
        for spine in panel_ax.spines.values():
            spine.set_visible(False)

# Axis titles and tick styling belong on the heatmap panel only; the returned axes also
# include legend panels, which should not be labelled "Sample ID" / "Marker Genes".
main_ax = axes_by_name["mainplot_ax"]
for label in main_ax.get_yticklabels():
    label.set_fontweight("bold")
    label.set_color("#3A3A3A")
main_ax.set_xlabel("Sample ID", fontsize = 12)
main_ax.set_ylabel("Marker Genes", fontsize = 12)
main_ax.tick_params(axis = "x", rotation = 90, labelsize = 10)
main_ax.tick_params(axis = "y", labelsize = 10)

matrix.fig.set_size_inches(4, 6)
matrix.fig.tight_layout()
matrix.fig.savefig(
    "matrixplot_marker_genes.png",
    dpi = 800,
    bbox_inches = "tight",
    transparent = True
)