# Drosophila Mature Neuronal Subset - CellRank Analysis

In [None]:
# =============================================================================
# Import Necessary Packages
# =============================================================================

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import cellrank as cr
import scvelo as scv
import pickle

In [None]:
# Additional packages for plotting, logging, etc.
import seaborn as sns
import matplotlib.pyplot as plt
import re
import logging
from scipy.stats import median_abs_deviation
import warnings

In [None]:
# =============================================================================
# Initialise Environment Settings
# =============================================================================

In [None]:
# Seed NumPy's global random number generator so NumPy-based randomness is repeatable.
np.random.seed(12345)

In [None]:
# Switch to the analysis folder; every relative path read or written below is
# resolved against it.
# NOTE(review): the path is machine-specific, and os.chdir raises if it does not exist.
working_directory = "/DataDrives/Drive2/Clifton/R_Projects/2025_Drosophila_scRNAseq_MonoamineSpecification/ANALYSIS/Step_4_CellRank/Step_4.2_Mature_neuronal_subset_analysis"
os.chdir(working_directory)
print(f"Current working directory: {os.getcwd()}")

In [None]:
# Configure Scanpy: verbosity level 3 and print the logging header.
# (No figure parameters are set here.)
sc.settings.verbosity = 3
sc.logging.print_header()

In [None]:
# Configure CellRank verbosity and silence every UserWarning (from any
# library) for the remainder of the session.
cr.settings.verbosity = 3
warnings.simplefilter("ignore", category=UserWarning)

In [None]:
# Configure scVelo: verbosity level 3 and the "scvelo" figure preset.
scv.settings.verbosity = 3
scv.settings.set_figure_params("scvelo")

In [None]:
# =============================================================================
# Load Data
# =============================================================================

In [None]:
# Annotation labels shared across analyses. Every label listed here has a
# matching entry in `shared_color_palette` below.
# NOTE(review): neither name is referenced again in the visible part of this
# notebook — confirm whether the plots are meant to use this palette.
shared_cell_types = [
    "Acetylcholine",
    "Acetylcholine/GABA",
    "GABA",
    "GABA/Glutamate",
    "GABA/Serotonin",
    "Glutamate",
    "Serotonin",
    "Serotonin/GABA",
    "Monoamine",
    "Monoamine/Acetylcholine",
    "Monoamine/Serotonin",
    "Monoamine/GABA",
    "Immature_neurons",
    "New-born_neurons/Immature_neurons",
    "Unknown_mature_neurons",
    "Neuroblasts",
    "Neuroblasts/GMCs",
    "Neuroblasts/GMCs/Immature_neurons",
    "GMCs",
    "GMCs/New-born_neurons/Immature_neurons",
    "Neuroblasts/GMCs/New-born_neurons/Immature_neurons",
    "Unknown"
]

# Mapping from annotation label to hex colour, grouped by transmitter family.
shared_color_palette = {
    # Acetylcholine-related
    "Acetylcholine": "#FFD700",  # vivid gold
    "Acetylcholine/GABA": "#FFEE58",  # sunflower yellow

    # GABA-related
    "GABA": "#B71C1C",  # dark red
    "GABA/Glutamate": "#D84315",  # burnt orange-red
    "GABA/Serotonin": "#F06292",  # deep pink
    "Serotonin/GABA": "#E91E63",  # strong pink-rose
    "Monoamine/GABA": "#EF5350",  # soft red-pink

    # Glutamate-related
    "Glutamate": "#43A047",  # strong green

    # Serotonin-related
    "Serotonin": "#8E24AA",  # deep purple

    # Monoamine-related
    "Monoamine": "#FB8C00",  # bright orange
    "Monoamine/Acetylcholine": "#FFA726",  # soft orange
    "Monoamine/Serotonin": "#FF7043",  # orange-coral

    # Neuroblasts & GMCs
    "Neuroblasts": "#1565C0",  # cobalt blue
    "Neuroblasts/GMCs": "#1E88E5",  # vivid blue
    "Neuroblasts/GMCs/Immature_neurons": "#64B5F6",  # sky blue
    "GMCs": "#0D47A1",  # navy
    "GMCs/New-born_neurons/Immature_neurons": "#1976D2",  # medium blue
    "Neuroblasts/GMCs/New-born_neurons/Immature_neurons": "#90CAF9",  # pale blue

    # Developmental/Immature/Unknown
    "Immature_neurons": "#29B6F6",  # bright cyan
    "New-born_neurons/Immature_neurons": "#4DD0E1",  # teal
    "Unknown_mature_neurons": "#757575",  # neutral gray
    "Unknown": "#BDBDBD"  # light gray
}

In [None]:
# Load the scVelo-processed subset; the file name is resolved relative to the
# working directory.
adata = sc.read_h5ad("4_Velocity_scVelo_processed_subset.h5ad")

In [None]:
# Bare expression: displays the AnnData summary as the cell output.
adata

In [None]:
# =============================================================================
# Parse and Order Timepoint Data
# =============================================================================

In [None]:
def parse_timepoint(tp):
    """
    Convert a timepoint label to a numerical value.

    Expects strings like "hrs_00_03" and returns the midpoint of the
    underscore-separated numbers, e.g., (0 + 3)/2 = 1.5. A label holding a
    single number (e.g. "hrs_12") returns that number.

    Parameters:
        tp (str): A timepoint string, with or without the "hrs_" prefix.

    Returns:
        float: The computed midpoint, or NaN when `tp` is not a string
        (e.g. None or a missing value) or any part is not numeric.
    """
    # Missing values reach this function as None/NaN rather than str; treat
    # them as unparseable instead of failing on the string methods below.
    if not isinstance(tp, str):
        return np.nan
    # Remove the "hrs_" prefix if present.
    if tp.startswith("hrs_"):
        tp = tp[len("hrs_"):]
    try:
        # str.split always yields at least one element, so no emptiness
        # check is needed; an empty part fails float() and yields NaN.
        nums = [float(part) for part in tp.split("_")]
    except ValueError:
        return np.nan  # Return NaN if conversion fails
    # Mean of the parsed numbers (the midpoint for a two-number label).
    return np.mean(nums)

In [None]:
# Apply the parse_timepoint function to create a new numerical column.
adata.obs["timepoint_numerical"] = adata.obs["timepoint"].apply(parse_timepoint)

In [None]:
# Group by the original timepoint and determine the first numerical value per group.
# Then sort the timepoints according to their numerical midpoints.
ordered_categories = (
    adata.obs.groupby("timepoint")["timepoint_numerical"]
    .first()
    .sort_values()
    .index
    .tolist()
)

In [None]:
# Update the "timepoint" column to be an ordered categorical variable.
adata.obs["timepoint"] = pd.Categorical(
    adata.obs["timepoint"],
    categories=ordered_categories,
    ordered=True
)

In [None]:
# Print the ordered timepoints and inspect a few numerical values.
print("Ordered timepoint categories:", ordered_categories)
print(adata.obs["timepoint_numerical"].head())
print(adata.obs[["timepoint", "timepoint_numerical"]].drop_duplicates())

In [None]:
# =============================================================================
# Save the Updated AnnData Object
# =============================================================================

In [None]:
# Write the updated AnnData object to file for subsequent analyses.
adata.write("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
# (Reload the object to ensure the saved version is used)
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
# =============================================================================
# UMAP Visualizations
# =============================================================================

In [None]:
# Plot the UMAP embedding colored by neuronal_annotation_fine.
sc.pl.embedding(
    adata, 
    basis='X_umap', 
    color="neuronal_annotation_fine", 
    legend_loc='right margin', 
    size=80, 
    save='CR_mature_neuronal_subset_annotation.svg'
)
sc.pl.embedding(
    adata, 
    basis='X_umap', 
    color="neuronal_annotation_fine", 
    legend_loc='right margin', 
    size=80, 
    save='CR_mature_neuronal_subset_annotation.pdf'
)

In [None]:
# Plot the UMAP embedding colored by latent time.
sc.pl.embedding(
    adata, 
    basis='X_umap', 
    color="latent_time", 
    size=80, 
    save='CR_mature_neuronal_subset_latent_time.svg'
)
sc.pl.embedding(
    adata, 
    basis='X_umap', 
    color="latent_time", 
    size=80, 
    save='CR_mature_neuronal_subset_latent_time.pdf'
)

In [None]:
# =============================================================================
# Connectivity Kernel Computations
# =============================================================================

In [None]:
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
# Create and compute a ConnectivityKernel based on cell-cell connectivity.
from cellrank.kernels import ConnectivityKernel
ck = ConnectivityKernel(adata)
ck = ck.compute_transition_matrix()

In [None]:
# Plot random walks using the ConnectivityKernel.
ck.plot_random_walks(
    seed=0,
    n_sims=200,
    start_ixs={"timepoint": "hrs_00_03"},
    basis="X_umap",
    legend_loc="right",
    dpi=300,
    save='CR_mature_neuronal_subset_ConnectivityKernel_RandWalks.svg'
)

In [None]:
# Write the connectivity kernel into the AnnData object and save.
ck.write_to_adata()
adata.write("MNsubset_ConnectvityKernel.h5ad")

In [None]:
# Reload the object and instantiate the kernel from it.
adata = sc.read("MNsubset_ConnectvityKernel.h5ad")
ck = cr.kernels.ConnectivityKernel.from_adata(adata, key="T_fwd")
print(ck)

In [None]:
# =============================================================================
# Velocity Kernel Computations
# =============================================================================

In [None]:
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
# Compute a VelocityKernel from the AnnData object.
vk = cr.kernels.VelocityKernel(adata)
vk = vk.compute_transition_matrix(model='deterministic', show_progress_bar=True)

In [None]:
# Plot the VelocityKernel projection colored by latent time.
vk.plot_projection(
    basis='X_umap', 
    color='latent_time', 
    color_map='gnuplot', 
    size=80, 
    dpi=300,
    save='CR_mature_neuronal_subset_VelocityKernel_latenttime.svg'
)
vk.plot_projection(
    basis='X_umap', 
    color='latent_time', 
    color_map='gnuplot', 
    size=80, 
    dpi=300,
    save='CR_mature_neuronal_subset_VelocityKernel_latenttime.pdf'
)

In [None]:
# Plot random walks based on the VelocityKernel.
vk.plot_random_walks(
    seed=0,
    n_sims=200,
    start_ixs={"timepoint": "hrs_00_03"},
    basis="X_umap",
    legend_loc="right",
    dpi=300,
    save='CR_mature_neuronal_subset_VelocityKernel_RandWalks.svg'
)

In [None]:
# Write the VelocityKernel to AnnData and save.
vk.write_to_adata()
adata.write("MNsubset_VelocityKernel.h5ad")
adata = sc.read("MNsubset_VelocityKernel.h5ad")
vk = cr.kernels.VelocityKernel.from_adata(adata, key="T_fwd")
print(vk)

In [None]:
# =============================================================================
# CytoTRACE Kernel Computations
# =============================================================================

In [None]:
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
adata

In [None]:
# Compute a CytoTRACEKernel.
from cellrank.kernels import CytoTRACEKernel
ctk = CytoTRACEKernel(adata)
ctk = ctk.compute_cytotrace()

In [None]:
# Plot embedding colored by CytoTRACE pseudotime and timepoint.
sc.pl.embedding(
    adata,
    color=["ct_pseudotime", "timepoint"],
    basis="X_umap",
    color_map="gnuplot2",
    save='CR_mature_neuronal_subset_CytoTRACEKerneltimepoint.svg'
)

In [None]:
# Violin plot for CytoTRACE pseudotime across timepoints.
sc.pl.violin(
    adata, 
    keys=["ct_pseudotime"], 
    groupby="timepoint", 
    rotation=90,
    save='CR_mature_neuronal_subset_vlnplt_CytoTRACEKerneltimepoint.svg'
)

In [None]:
# Compute transition matrix with a soft threshold using CytoTRACEKernel.
ctk.compute_transition_matrix(threshold_scheme="soft", nu=0.5)

In [None]:
# Plot CytoTRACEKernel projection (colored by timepoint).
ctk.plot_projection(
    basis="X_umap", 
    color="timepoint", 
    legend_loc="right",
    size=80, 
    dpi=300, 
    save='CR_mature_neuronal_subset_CytoTRACEKernel_timepoint.svg'
)

In [None]:
# Plot random walks based on the CytoTRACEKernel.
ctk.plot_random_walks(
    seed=0,
    n_sims=200,
    start_ixs={"timepoint": "hrs_00_03"},
    basis="X_umap",
    legend_loc="right",
    dpi=300,
    save='CR_mature_neuronal_subset_CytoTRACEKernel_RandWalks.svg'
)

In [None]:
# Write the CytoTRACE kernel to AnnData and save.
ctk.write_to_adata()
adata.write("MNsubset_CytoTRACEKernel.h5ad")
adata = sc.read("MNsubset_CytoTRACEKernel.h5ad")
ctk = cr.kernels.CytoTRACEKernel.from_adata(adata, key="T_fwd")
print(ctk)

In [None]:
# =============================================================================
# RealTime Kernel Computations
# =============================================================================

In [None]:
# Reload the original starting AnnData object.
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")

In [None]:
# Visualize embedding colored by numerical timepoint and neuronal annotation.
sc.pl.embedding(
    adata,
    basis="X_umap",
    color=["timepoint_numerical", "neuronal_annotation_fine"],
    color_map="gnuplot",
)

In [None]:
# Prepare and solve the temporal problem.
tp = TemporalProblem(adata)
tp = tp.prepare(time_key="timepoint")
tp = tp.solve(epsilon=1e-3, tau_a=0.95, scale_cost="mean")

In [None]:
# Create a RealTimeKernel from the temporal problem and compute its transition matrix.
from cellrank.kernels import RealTimeKernel
tmk = RealTimeKernel.from_moscot(tp)
tmk.compute_transition_matrix(self_transitions="all", conn_weight=0.2, threshold="auto")

In [None]:
# Plot random walks based on the RealTimeKernel.
tmk.plot_random_walks(
    seed=0,
    n_sims=200,
    start_ixs={"timepoint": "hrs_00_03"},
    basis="X_umap",
    legend_loc="right",
    dpi=300,
    save='CR_mature_neuronal_subset_CytoTRACEKernel_RandWalks.svg'
)

In [None]:
# Write the RealTimeKernel to AnnData and save.
tmk.write_to_adata()
adata.write("MNsubset_RealTimeKernel.h5ad")
adata = sc.read("MNsubset_RealTimeKernel.h5ad")
tmk = cr.kernels.RealTimeKernel.from_adata(adata, key="T_fwd")
print(tmk)

In [None]:
# =============================================================================
# Combine Kernels and GPCCA Analysis
# =============================================================================

In [None]:
# Combine the computed kernels using weighted summation.
# Here, weights are assigned as: VelocityKernel (0.65), CytoTRACEKernel (0.15), and ConnectivityKernel (0.2)
adata = sc.read("MNsubset_ConnectvityKernel.h5ad")
ck = cr.kernels.ConnectivityKernel.from_adata(adata, key="T_fwd")

In [None]:
adata = sc.read("MNsubset_VelocityKernel.h5ad")
vk = cr.kernels.VelocityKernel.from_adata(adata, key="T_fwd")

In [None]:
adata = sc.read("MNsubset_CytoTRACEKernel.h5ad")
ctk = cr.kernels.CytoTRACEKernel.from_adata(adata, key="T_fwd")

In [None]:
adata = sc.read_h5ad("1_CellRank_Starting_Data_subset.h5ad")
combined_kernel = 0.75 * vk + 0.15 * ctk + 0.1 * ck
print(combined_kernel)

In [None]:
# Create a GPCCA estimator from the combined kernel.
g = cr.estimators.GPCCA(combined_kernel)
print(g)

In [None]:
# Compute the Schur decomposition with 30 components and plot the spectrum.
g.compute_schur(n_components=100)
g.plot_spectrum(real_only=True, dpi=300, save='CR_mature_neuronal_subset_CombKer_SchurSpectrum.svg', figsize=(28,6))
g.plot_spectrum(real_only=True, dpi=300, save='CR_mature_neuronal_subset_CombKer_SchurSpectrum.pdf', figsize=(28,6))

In [None]:
print(adata.obs['neuronal_annotation_fine'])

In [None]:
# Compute macrostates (24 states) based on the "neuronal_annotation_fine" grouping.
g.compute_macrostates(n_states=30, cluster_key="neuronal_annotation_fine")
g.plot_macrostates(which="all", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Macrostates.svg')
g.plot_macrostates(which="all", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Macrostates.pdf')

In [None]:
# Plot the composition of macrostates with respect to "neuronal_annotation_fine".
g.plot_macrostate_composition(key="neuronal_annotation_fine", figsize=(7, 4), dpi=300, save='CR_mature_neuronal_subset_CombKer_MacrostatesComp.svg')

g.plot_macrostate_composition(key="neuronal_annotation_fine", figsize=(7, 4), dpi=300, save='CR_mature_neuronal_subset_CombKer_MacrostatesComp.pdf')

In [None]:
# Plot the coarse transition matrix of the macrostates.
g.plot_coarse_T(annotate=False, dpi=300, save='CR_mature_neuronal_subset_CombKer_CoarseT.svg')
g.plot_coarse_T(annotate=False, dpi=300, save='CR_mature_neuronal_subset_CombKer_CoarseT.pdf')

In [None]:
g.plot_macrostates(which="all", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Macrostates.pdf')
g.plot_macrostates(which="all", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Macrostates.svg')

In [None]:
macrostate_labels = g.macrostates

# List the named macrostates from the categorical's categories. Using
# .unique() would also return the NaN placeholder of cells that belong to no
# macrostate, which then occupies one slot of the 15-per-figure split below.
unique_macrostates = macrostate_labels.cat.categories.tolist()

# Print all macrostates
print("Unique macrostates:", unique_macrostates)

# Work on a copy of the estimator's AnnData and attach the labels as strings.
adata = g.adata.copy()
adata.obs["macrostates"] = macrostate_labels.astype(str)

# Plot the macrostates in groups of 15 so each legend stays readable: with 30
# macrostates this writes "_macrostates_part1.pdf" and "_macrostates_part2.pdf".
states_per_figure = 15
chunk_starts = range(0, len(unique_macrostates), states_per_figure)
for part_number, start in enumerate(chunk_starts, start=1):
    states_in_figure = unique_macrostates[start:start + states_per_figure]
    subset = adata[adata.obs["macrostates"].isin(states_in_figure)]
    sc.pl.umap(
        subset,
        color="macrostates",
        save=f"_macrostates_part{part_number}.pdf",
        show=False,
        size=100,
    )


In [None]:
# =============================================================================
# Set Initial and Terminal States and Visualize Macrostates
# =============================================================================

In [None]:
# Predict initial states and then manually set the desired initial states.
g.predict_initial_states()
g.set_initial_states(states=['Neuroblasts/GMCs_1',  'Neuroblasts/GMCs_2', 'Neuroblasts/GMCs_3', 'Neuroblasts/GMCs_4', 'Neuroblasts/GMCs_5', 'Neuroblasts/GMCs_6'])

g.plot_macrostates(which="initial", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Initialstates.pdf')
g.plot_macrostates(which="initial", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Initialstates.svg')


In [None]:
# Predict terminal states and set the desired terminal states.
g.set_terminal_states(states=['Glutamate', 'Acetylcholine', 'Monoamine', 'GABA_1', 'GABA_2'], allow_overlap=True)
g.plot_macrostates(which="terminal", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Terminalstates.pdf')
g.plot_macrostates(which="terminal", legend_loc="right", s=100, dpi=300, save='CR_mature_neuronal_subset_CombKer_Terminalstates.svg')

In [None]:
# =============================================================================
# Save the GPCCA ("g") Object
# =============================================================================

In [None]:
# Save the estimator object to disk using pickle.
with open("CR_mature_neuronal_subset_CombKer_GPCCA.pkl", "wb") as f:
    pickle.dump(g, f)

print("GPCCA estimator object saved successfully.")

In [None]:
with open("CR_mature_neuronal_subset_CombKer_GPCCA.pkl", "rb") as f:
     g = pickle.load(f)

In [None]:
# Compute fate probabilities and plot the results.
g.compute_fate_probabilities(n_jobs=1, show_progress_bar=True)

In [None]:
g.plot_fate_probabilities(same_plot=False, dpi=300, save='CR_mature_neuronal_subseture_neuronal_subset_CombKer_FateProbabilites.pdf')

g.plot_fate_probabilities(same_plot=False, dpi=300, save='CR_mature_neuronal_subset_CombKer_FateProbabilites.svg')

In [None]:
# =============================================================================
# Circular Projection of Macrostates
# =============================================================================

In [None]:
# Create a new AnnData object from the GPCCA estimator.
adata = g.to_adata(keep='all', copy=True)

In [None]:
# Now generate the circular projection using the updated adata.
cr.pl.circular_projection(
    adata, 
    keys=["neuronal_annotation_fine"], 
    legend_loc="right", 
    dpi=300, 
    save='CR_mature_neuronal_subset_CombKer_CircularProjection.pdf'
)

In [None]:
# Now generate the circular projection using the updated adata.
cr.pl.circular_projection(
    adata, 
    keys=["neuronal_annotation_fine"], 
    legend_loc="right", 
    dpi=300, 
    save='CR_mature_neuronal_subset_CombKer_CircularProjection.svg'
)

In [None]:
sc.pl.violin(adata, keys=["velocity_pseudotime"], groupby="neuronal_annotation_fine", rotation=90, save='mature_neuronal_subset_VlnPlt_velocityseudotime.pdf')

sc.pl.violin(adata, keys=["velocity_pseudotime"], groupby="neuronal_annotation_fine", rotation=90, save='mature_neuronal_subset_VlnPlt_velocityseudotime.svg')

In [None]:
# =============================================================================
# Trajectories
# =============================================================================

In [None]:
# Parameters
top_n = 10

lineages = ['Glutamate', 'Acetylcholine', 'Monoamine', 'GABA_1', 'GABA_2']
cluster_key = "neuronal_annotation_fine"
early_states = ['Neuroblasts/GMCs']

In [None]:
# Ensure mean expression exists for coloring
adata.var["mean expression"] = adata.X.A.mean(axis=0)

In [None]:
# Plot early states on UMAP
sc.pl.embedding(
    adata, basis="umap", color=cluster_key, groups=early_states,
    legend_loc="right", save='CR_mature_neuronal_subset_CombKer_EarlyStates.pdf'
)
sc.pl.embedding(
    adata, basis="umap", color=cluster_key, groups=early_states,
    legend_loc="right", save='CR_mature_neuronal_subset_CombKer_EarlyStates.svg'
)

In [None]:
# Initialize containers
top_genes_dict = {}
combined_df = []

In [None]:
# =============================
# Compute & process drivers per lineage
# =============================

In [None]:
for lineage in lineages:
    print(f"\nProcessing lineage: {lineage}")
    # Names such as "GABA_1" map back to the cluster "GABA" (text before the
    # first underscore); the early states are always included.
    cluster_lineage = lineage.partition('_')[0]
    clusters = early_states + [cluster_lineage]
    print(f"  Using clusters: {clusters}")

    # Names derived from the lineage: file-safe label, q-value column, obs key.
    safe_lineage = lineage.replace("/", "_")
    qval_col = f"{lineage}_qval"
    fate_key = f"fate_probabilities_{lineage.lower()}"

    # Driver genes for this lineage, restricted to the selected clusters.
    df_lineage = g.compute_lineage_drivers(
        lineages=[lineage], cluster_key=cluster_key, clusters=clusters
    )
    df_lineage["lineage"] = lineage
    # Full table goes to CSV and into the list that is concatenated later.
    df_lineage.to_csv(f"{safe_lineage}_drivers.csv", index=True)
    combined_df.append(df_lineage)

    # Top genes: smallest q-values first; gene names are the index.
    df_lineage = df_lineage.dropna(subset=[qval_col])
    ranked_by_qval = df_lineage.sort_values(qval_col, ascending=True)
    top_genes = ranked_by_qval.head(top_n).index.tolist()
    top_genes_dict[lineage] = top_genes

    # Fate probabilities of this lineage as an obs column, then UMAP panels of
    # the fate probability and the top genes in both output formats.
    adata.obs[fate_key] = g.fate_probabilities[lineage].X.flatten()
    for fate_figure in (f"FateProbs_{safe_lineage}.pdf", f"FateProbs_{safe_lineage}.png"):
        sc.pl.embedding(
            adata, basis="X_umap", color=[fate_key] + top_genes,
            color_map="viridis", s=50, ncols=3, vmax="p96",
            save=fate_figure
        )

    # Disabled: per-lineage violin plots via
    #   cr.pl.aggregate_fate_probabilities(
    #       adata, mode="violin", lineages=[lineage], cluster_key=cluster_key,
    #       save=f"AggregateFateProb_{lineage}.pdf")   # and the .svg variant


In [None]:
# =============================
# Save combined driver table
# =============================

In [None]:
combined_df = pd.concat(combined_df, axis=0)
combined_df.to_csv("combined_lineage_drivers.csv")

print("All lineage driver analyses completed and saved.")

In [None]:
# =============================
# Ensure mean expression exists for coloring
# =============================

In [None]:
if hasattr(adata.X, "toarray"):  # sparse matrix
    adata.var["mean expression"] = np.array(adata.X.mean(axis=0)).flatten()
else:  # dense matrix
    adata.var["mean expression"] = adata.X.mean(axis=0)

In [None]:
# =============================
# Compute all drivers at once
# =============================

In [None]:
driver_df = g.compute_lineage_drivers(lineages=lineages)
adata = g.to_adata(keep='all', copy=True)

print("varm:", adata.varm.keys())
# Manually need to add this adata.varm (unsure why it doesn't do it automatically)
adata.varm["terminal_lineage_drivers"] = driver_df.to_records(index=True)
print("varm:", adata.varm.keys())

In [None]:
# =============================
# Extract top N genes per lineage based on qval
# =============================

In [None]:
top_genes_dict = {}

In [None]:
for lineage in lineages:
    qval_col = f"{lineage}_qval"
    # Drop NaNs in q-value column
    filtered_df = driver_df.dropna(subset=[qval_col])
    # Sort by ascending q-value and select top genes
    top_genes = (
        filtered_df.sort_values(qval_col, ascending=True)
        .head(top_n)
        .index.tolist()
    )
    top_genes_dict[lineage] = top_genes

In [None]:
# =============================
# Plot driver correlation between all lineage pairs
# =============================

In [None]:
from itertools import combinations

# Every unordered pair of lineages is plotted exactly once, in the same order
# as nested index loops with i < j.
for lin_x, lin_y in combinations(lineages, 2):
    print(f"Plotting correlation between {lin_x} and {lin_y}...")
    g.plot_lineage_drivers_correlation(
        lineage_x=lin_x,
        lineage_y=lin_y,
        adjust_text=True,
        # Highlight the top driver genes of both lineages.
        gene_sets={
            lin_x: top_genes_dict[lin_x],
            lin_y: top_genes_dict[lin_y],
        },
        # NOTE(review): this is a method of the estimator, so the column is
        # presumably looked up in g.adata.var rather than in the separate
        # `adata` copy — confirm that "mean expression" exists there.
        color="mean expression",
        legend_loc="none",
        figsize=(6, 6),
        dpi=300,
        fontsize=9,
        size=50,
        save=f"DriverCorr_{lin_x}_vs_{lin_y}.png"
    )


In [None]:
# =============================
# Latent time violin plot
# =============================

In [None]:
# Order the clusters by their mean latent time (ascending) and make the
# cluster column an ordered categorical so the violins follow that order.
mean_latent_time = adata.obs.groupby(cluster_key)["latent_time"].mean()
lt_order = mean_latent_time.sort_values().index.tolist()
adata.obs[cluster_key] = pd.Categorical(
    adata.obs[cluster_key], categories=lt_order, ordered=True
)

# Latent time per cluster, once per output format.
for latent_time_violin_figure in ('LatentTimeVln.pdf', 'LatentTimeVln.svg'):
    sc.pl.violin(
        adata, keys=["latent_time"], groupby=cluster_key,
        rotation=90, save=latent_time_violin_figure
    )

In [None]:
# =============================
# Fit GAMR model for trends
# =============================

In [None]:
# Save the estimator object to disk using pickle.
with open("CR_mature_neuronal_subset_CombKer_GPCCA_2.pkl", "wb") as f:
    pickle.dump(g, f)

print("GPCCA estimator object saved successfully.")

In [None]:
with open("CR_mature_neuronal_subset_CombKer_GPCCA_2.pkl", "rb") as f:
     g = pickle.load(f)

In [None]:
adata = g.to_adata(keep='all', copy=True)
model = cr.models.GAMR(adata, n_knots=6)


In [None]:
# =============================
# Plot gene trends and heatmaps for each lineage
# =============================

In [None]:
# assume `lineages` is the list of lineage keys you want to plot
n_lin = len(lineages)
print(lineages)
color_list = [
    "#B71C1C", 
    "#FB8C00",
    "#B71C1F",
    "#FFD700", 
    "#D84315",
    "#F06292",
    "#43A047"
]
# assign your list of hex‐colors into uns
adata.uns["term_states_fwd_colors"] = color_list

In [None]:
for lineage in lineages:
    print(f"\nPlotting gene trends for: {lineage}")
    cluster_lineage = lineage.split('_')[0]  # Remove _1/_2 suffix
    # Use qval column instead of 'lineage' column
    qval_col = f"{lineage}_qval"
    if qval_col not in driver_df.columns:
        print(f"  Skipping {lineage} — no qval column found.")
        continue
    # Filter and sort top 40 drivers
    df_lineage = driver_df.dropna(subset=[qval_col])
    genes = df_lineage.sort_values(qval_col).head(40).index.tolist()
    # Gene trend plot
    cr.pl.gene_trends(
        adata,
        model=model,
        genes=genes[:8],
        same_plot=True,
        ncols=2,
        time_key="latent_time",
        hide_cells=True,
        weight_threshold=(1e-3, 1e-3),
        save=f"GeneTrends_TopDrivers_{lineage}.png"
    )
    # Heatmap
    cr.pl.heatmap(
        adata,
        model=model,
        lineages=lineage,
        cluster_key=cluster_key,
        show_fate_probabilities=True,
        genes=genes,
        time_key="latent_time",
        figsize=(12, 10),
        show_all_genes=True,
        weight_threshold=(1e-3, 1e-3),
        save=f"HeatmapModelGeneTrends_{lineage}.png"
    )
    
# (Optional) You can inspect if the key exists now:
print("Keys in adata.varm:", adata.varm.keys())
print("Keys in adata.var:", adata.var.keys())


In [None]:
terminal_state_names = g.terminal_states.unique().tolist()
print(terminal_state_names)


In [None]:

# List of genes
genes = ["fd59A", "dmrt99B", "CG4328", "Vsx1", "Vsx2", "Lmx1a", "Atf3", "CG3104"]

# Create output folder for individual gene plots
output_folder = "GeneTrends_TFs"
os.makedirs(output_folder, exist_ok=True)

# Plot each gene individually
for gene in genes:
    cr.pl.gene_trends(
        adata,
        model=model,
        genes=[gene],
        same_plot=True,
        ncols=1,
        time_key="latent_time",
        hide_cells=True,
        weight_threshold=(1e-3, 1e-3),
        n_jobs=1,
        save=f"{output_folder}/GeneTrend_{gene}.pdf"
    )

# Combined plot (all genes together)
cr.pl.gene_trends(
    adata,
    model=model,
    genes=genes,
    same_plot=True,
    ncols=2,
    time_key="latent_time",
    hide_cells=True,
    weight_threshold=(1e-3, 1e-3),
    n_jobs=1,
    save="CR_mature_neuronal_subset_CombKer_ModelGeneTrendsTFs_Monoamine.pdf"
)

In [None]:

# List of genes
genes = ["cic", "NfI", "cbt", "CG4328", "Lmx1a", "FoxP", "trh", "Eip78C", "net", "Pdp1", "luna"]

# Create output folder for individual gene plots
output_folder = "GeneTrends_TFs2"
os.makedirs(output_folder, exist_ok=True)

# Plot each gene individually
for gene in genes:
    cr.pl.gene_trends(
        adata,
        model=model,
        genes=[gene],
        same_plot=True,
        ncols=1,
        time_key="latent_time",
        hide_cells=True,
        weight_threshold=(1e-3, 1e-3),
        n_jobs=1,
        save=f"{output_folder}/GeneTrend_{gene}.pdf"
    )

# Combined plot (all genes together)
cr.pl.gene_trends(
    adata,
    model=model,
    genes=genes,
    same_plot=True,
    ncols=2,
    time_key="latent_time",
    hide_cells=True,
    weight_threshold=(1e-3, 1e-3),
    n_jobs=1,
    save="CR_mature_neuronal_subset_CombKer_ModelGeneTrendsTFs2_Monoamine.pdf"
)

In [None]:
# List of enzyme genes
enzyme_genes = ["Trh", "Ddc", "ple", "Vmat", "Tbh", "Tdc2", "SerT", "DAT"]

# Create output folder for individual enzyme gene plots
enzyme_output_folder = "GeneTrends_Enzymes"
os.makedirs(enzyme_output_folder, exist_ok=True)

# Plot each enzyme gene individually
for gene in enzyme_genes:
    cr.pl.gene_trends(
        adata,
        model=model,
        genes=[gene],
        same_plot=True,
        ncols=1,
        time_key="latent_time",
        hide_cells=True,
        legend_loc=None,
        weight_threshold=(1e-3, 1e-3),
        n_jobs=1,
        save=f"{enzyme_output_folder}/GeneTrend_{gene}.pdf"
    )

# Combined plot (all enzyme genes together)
cr.pl.gene_trends(
    adata,
    model=model,
    genes=enzyme_genes,
    same_plot=True,
    ncols=2,
    time_key="latent_time",
    hide_cells=True,
    legend_loc=None,
    weight_threshold=(1e-3, 1e-3),
    n_jobs=1,
    save="CR_mature_neuronal_subset_CombKer_ModelGeneTrendsEnzymes_Monoamine.pdf"
)

In [None]:
# Table of cell counts per neuronal_annotation_fine
cell_counts = adata.obs["neuronal_annotation_fine"].value_counts().sort_index()
print(cell_counts)
cell_counts.to_csv("neuronal_annotation_fine_cell_counts.csv")

In [None]:
# compute putative drivers for the Beta trajectory
monoamine_drivers = g.compute_lineage_drivers(
    lineages=["Monoamine"],
    cluster_key="neuronal_annotation_fine")

# plot heatmap
cr.pl.heatmap(
    adata,
    model=model,  # use the model from before
    lineages="Monoamine",
    cluster_key="neuronal_annotation_fine",
    show_fate_probabilities=True,
    genes=monoamine_drivers.head(75).index,
    time_key="latent_time",
    figsize=(12, 16),
    show_all_genes=True,
    weight_threshold=(1e-3, 1e-3),
save='CR_mature_neuronal_subset_CombKer_HeatmapModelGeneTrends_Monoamine.pdf')

In [None]:
# Compute putative drivers for all terminal lineages
all_lineage_drivers = g.compute_lineage_drivers(lineages=lineages, cluster_key="neuronal_annotation_fine")

# Plot heatmap for all lineages (top 75 genes by min q-value across all lineages)
# Get top 75 genes with lowest min q-value across all lineages
min_qval = all_lineage_drivers[[f"{l}_qval" for l in lineages]].min(axis=1)
top_genes = min_qval.nsmallest(75).index

cr.pl.heatmap(
    adata,
    model=model,
    lineages=lineages,
    cluster_key="neuronal_annotation_fine",
    show_fate_probabilities=True,
    genes=top_genes,
    time_key="latent_time",
    figsize=(14, 18),
    show_all_genes=True,
    weight_threshold=(1e-3, 1e-3),
    save="CR_mature_neuronal_subset_CombKer_HeatmapModelGeneTrends_AllLineages.pdf"
)

In [None]:
# Compute putative drivers for the full neurogenic lineage (all relevant clusters)
# Here, we use all clusters in the neurogenic trajectory, e.g. all in 'neuronal_annotation_fine'
all_neurogenic_drivers = g.compute_lineage_drivers(
    lineages=None,  # None or all clusters, depending on CellRank version
    cluster_key="neuronal_annotation_fine"
)

# Select top 75 genes by lowest q-value across all clusters
qval_cols = [col for col in all_neurogenic_drivers.columns if col.endswith("_qval")]
min_qval = all_neurogenic_drivers[qval_cols].min(axis=1)
top_genes = min_qval.nsmallest(75).index

# Plot heatmap for the full neurogenic lineage
cr.pl.heatmap(
    adata,
    model=model,
    lineages=None,  # or all clusters, if required
    cluster_key="neuronal_annotation_fine",
    show_fate_probabilities=True,
    genes=top_genes,
    time_key="latent_time",
    figsize=(16, 20),
    show_all_genes=True,
    weight_threshold=(1e-3, 1e-3),
    save="CR_mature_neuronal_subset_CombKer_HeatmapModelGeneTrends_NeurogenicLineage.pdf"
)

In [None]:
# Gene groups to plot; each group gets its own output folder "GeneTrends_<group>".
gene_groups = {
    "Enzymes": ["Trh", "Ddc", "ple", "Vmat", "Tbh", "Tdc2", "SerT", "DAT", "Hdc"],
    "SerReceptors": ["5-HT1A", "5-HT1B", "5-HT2A", "5-HT2B", "5-HT7"],
    "DopReceptors": ["Dop1R1", "Dop1R2", "Dop2R", "DopEcR"],
    "OctReceptors": ["Oct-TyrR", "Octbeta1R", "Octbeta3R", "Oamb", "Octbeta2R", "Octalpha2R"],
    "HisReceptors": ["HisCl1"],
    "TyrReceptors": ["Oct-TyrR", "TyrR", "TyrRII"],
    "TF3": ["vvl", "CG32532", "Lmx1a", "Ets65A"]

}

for group, genes in gene_groups.items():
    # `folder` was previously never defined, so os.makedirs raised NameError.
    # Saving into a per-group folder (same scheme as the TF and enzyme cells
    # above) also keeps genes listed in two groups, e.g. "Oct-TyrR", from
    # overwriting each other's figure.
    folder = f"GeneTrends_{group}"
    os.makedirs(folder, exist_ok=True)
    # Plot each gene individually
    for gene in genes:
        cr.pl.gene_trends(
            adata,
            model=model,
            genes=[gene],
            same_plot=True,
            ncols=1,
            time_key="latent_time",
            hide_cells=True,
            legend_loc=None,
            weight_threshold=(1e-3, 1e-3),
            n_jobs=1,
            save=f"{folder}/GeneTrend_{gene}.pdf"
        )
    # Combined plot for the group
    cr.pl.gene_trends(
        adata,
        model=model,
        genes=genes,
        same_plot=True,
        ncols=2,
        time_key="latent_time",
        hide_cells=True,
        legend_loc=None,
        weight_threshold=(1e-3, 1e-3),
        n_jobs=1,
        save=f"CR_mature_neuronal_subset_CombKer_ModelGeneTrends{group}_Monoamine.pdf"
    )