In [None]:
## 01 Trajectory inference by Palantir
import matplotlib
import matplotlib.pyplot as plt
import palantir
import pandas as pd
import scanpy as sc
import scanpy.external as sce
from cellrank.tl.estimators import GPCCA
from cellrank.tl.kernels import ConnectivityKernel, CytoTRACEKernel
plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts in PDFs so text stays editable

# Configuration for the trajectory analysis.
ADATA_PATH = "/path/to/subset_mVEC.h5ad"  # subset mVEC cells (placeholder path)
# obs column holding the aEC state labels; sections 02 and 03 read 'aEC_states' as well.
STATE_KEY = 'aEC_states'
INITIAL_STATE = 'caEC'
TERMINAL_STATE = 'R2'

adata = sc.read(ADATA_PATH)
sc.pp.neighbors(adata, use_rep='X_pca_harmony')
sc.tl.umap(adata)

# CellRank GPCCA on a connectivity kernel, used to pick a representative start
# cell (initial macrostate) and end cell (terminal macrostate) for Palantir.
ck = ConnectivityKernel(adata)
ck.compute_transition_matrix()
g_pv = GPCCA(ck)
g_pv.compute_schur(n_components=10)
g_pv.plot_spectrum(real_only=True)
g_pv.compute_macrostates(n_states=10, cluster_key=STATE_KEY)
g_pv.plot_macrostates(discrete=True, legend_loc="right", size=100, basis="X_umap")
g_pv.set_terminal_states_from_macrostates([TERMINAL_STATE])
# NOTE(review): private CellRank method; its name/signature may change between versions.
g_pv._set_initial_states_from_macrostates(INITIAL_STATE)
g_pv.compute_absorption_probabilities()
g_pv.plot_absorption_probabilities(same_plot=False)

# idxmax returns the obs_name of the highest-probability cell directly.
start_cell = adata.obs['initial_states_probs'].idxmax()
# Terminal-state probabilities use the same '<key>_probs' naming as
# 'initial_states_probs' above.
# NOTE(review): confirm the column name against the installed CellRank version.
end_cell = adata.obs['terminal_states_probs'].idxmax()

# Palantir: diffusion maps on the harmony PCA, multiscale space and MAGIC
# imputation, then pseudotime from start_cell towards the single terminal state.
dm_res = palantir.utils.run_diffusion_maps(adata, n_components=5, pca_key='X_pca_harmony')
ms_data = palantir.utils.determine_multiscale_space(adata)
imputed_X = palantir.utils.run_magic_imputation(adata)
terminal_states = pd.Series([TERMINAL_STATE], index=[end_cell])

pr_res = palantir.core.run_palantir(
    adata, start_cell, num_waypoints=500, terminal_states=terminal_states)

# Plotting
sc.set_figure_params(dpi=80)
# Four colours, presumably one per aEC state (caEC, R0, R1, R2).
palette1 = ['#E69F00', '#56B4E9', '#009E73', '#F0E442']
sc.pl.umap(adata, color=['palantir_pseudotime', STATE_KEY], cmap='viridis', palette=palette1)
masks = palantir.presults.select_branch_cells(adata, q=.01, eps=.01)
palantir.plot.plot_trajectory(adata, TERMINAL_STATE)

In [None]:
## 02 aEC state prediction from PDE3A/VIM expression

import pandas as pd
import scanpy as sc
import os
import joblib
import pickle
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import label_binarize
from sklearn.utils import resample
from sklearn.metrics import roc_auc_score, average_precision_score,roc_auc_score,RocCurveDisplay,auc

# Features: expression of two marker genes; target: the per-cell aEC state label.
MARKER_GENES = ['PDE3A', 'VIM']
SEED = 42

expr = adata[:, MARKER_GENES].X
# .X may be scipy-sparse or a dense ndarray; only the selected columns are densified.
X = np.asarray(expr.toarray() if hasattr(expr, 'toarray') else expr, dtype=float)
Y = np.asarray(adata.obs['aEC_states'].values)

# One-vs-rest indicator matrix with exactly one column per class
# (label_binarize collapses the 2-class case to a single column).
classes = np.unique(Y)
y_bin = (Y[:, None] == classes[None, :]).astype(int)
n_classes = y_bin.shape[1]

# Split BEFORE scaling so test cells do not leak into the scaler's mean/std.
# Labels (Y_*) and indicator rows (y_*) are split together so they stay aligned;
# stratify keeps every class represented in the test set for per-class AUROC.
X_train_raw, X_test_raw, Y_train, Y_test, y_train, y_test = train_test_split(
    X, Y, y_bin, test_size=0.3, random_state=SEED, stratify=Y)

scaler = StandardScaler().fit(X_train_raw)
X_train = scaler.transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# lbfgs already fits a multinomial (softmax) model for >2 classes; the explicit
# multi_class='multinomial' argument is deprecated in recent scikit-learn.
model = LogisticRegression(solver='lbfgs')
# Fit on the labels themselves; predict_proba columns follow model.classes_,
# which must match `classes` (the column order of y_bin / y_test).
model.fit(X_train, Y_train)
assert list(model.classes_) == list(classes), "class order mismatch between model and y_bin"

# The classifier expects scaled input, so persist the train-fit scaler as well.
with open('LR.pkl', 'wb') as f:
    pickle.dump(model, f)
with open('LR_scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# Held-out class probabilities, columns ordered as `classes`.
y_score = model.predict_proba(X_test)


## AUROC / AUPR evaluation (per class, one-vs-rest)
def bootstrap_auroc(y_true, y_score, classes, n_bootstraps=100, seed=42):
    """Bootstrap per-class one-vs-rest AUROC with mean, std and a 95% percentile CI.

    Parameters
    ----------
    y_true : array-like
        Either 1-D class labels (each one of ``classes``) or a 2-D 0/1 indicator
        matrix with one column per class, in the order of ``classes``.
    y_score : array-like of shape (n_samples, len(classes))
        Predicted scores/probabilities; column ``j`` scores ``classes[j]``.
    classes : sequence
        Class labels, in the column order of ``y_score``.
    n_bootstraps : int
        Number of resamples (sampling with replacement) to draw.
    seed : int
        Seed of the resampling RNG, for reproducible intervals.

    Returns
    -------
    dict
        ``{label: {'mean': float, 'std': float, '95% CI': (lower, upper)}}``.
        For each class, resamples in which it is all-positive or all-negative
        are skipped. A class with no usable resample gets NaN statistics.

    Raises
    ------
    ValueError
        If a 2-D ``y_true`` does not have one column per class.
    """
    rng = np.random.RandomState(seed)
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    classes_arr = np.asarray(classes)

    if y_true.ndim == 2:
        # Already one-vs-rest indicators.
        if y_true.shape[1] != len(classes_arr):
            raise ValueError(
                f"y_true has {y_true.shape[1]} columns but {len(classes_arr)} classes were given")
        y_true_bin = y_true.astype(int)
    else:
        # Always one column per class. label_binarize returns a single column
        # for two classes, which would break the per-class indexing below.
        y_true_bin = (y_true[:, None] == classes_arr[None, :]).astype(int)

    aucs = {label: [] for label in classes}
    n_samples = len(y_true)

    for _ in range(n_bootstraps if n_samples else 0):
        # Resample indices with replacement.
        indices = rng.randint(0, n_samples, n_samples)
        if len(np.unique(y_true[indices])) < 2:
            continue  # not enough class variety in this resample
        y_true_resampled = y_true_bin[indices]
        y_score_resampled = y_score[indices]

        for j, label in enumerate(classes):
            column = y_true_resampled[:, j]
            # AUROC is undefined when this class is all-positive or all-negative.
            if column.min() == column.max():
                continue
            aucs[label].append(roc_auc_score(column, y_score_resampled[:, j]))

    summary = {}
    for label in classes:
        scores = np.asarray(aucs[label], dtype=float)
        if scores.size == 0:
            # np.percentile raises on an empty array; report "no estimate" instead.
            summary[label] = {'mean': np.nan, 'std': np.nan, '95% CI': (np.nan, np.nan)}
            continue
        summary[label] = {
            'mean': scores.mean(),
            'std': scores.std(),
            '95% CI': (np.percentile(scores, 2.5), np.percentile(scores, 97.5)),
        }
    return summary

# Per-class (one-vs-rest) AUROC / AUPR on the held-out cells.
# y_test has one indicator column per class; y_score has the matching predicted
# probabilities (columns ordered as `classes`). `auroc` is used rather than `auc`
# so that sklearn.metrics.auc is not shadowed.
for i in range(n_classes):
    auroc = roc_auc_score(y_test[:, i], y_score[:, i])
    aupr = average_precision_score(y_test[:, i], y_score[:, i])
    print(classes[i], "AUROC:", auroc, 'AUPR:', aupr)

# Bootstrap on the TEST labels, so they line up with y_score row-for-row.
# Each test label is recovered from its indicator row.
y_test_labels = classes[np.argmax(y_test, axis=1)]
result = bootstrap_auroc(y_test_labels, y_score, classes)

for label, stats in result.items():
    print(f"Class {label}: AUC = {stats['mean']:.3f} ± {stats['std']:.3f} "
          f"(95% CI: {stats['95% CI'][0]:.3f} - {stats['95% CI'][1]:.3f})")


In [None]:
## 03 DEG analysis

# Wilcoxon test of each state R0-R2 against the caEC reference, one TSV per comparison.
reference_state = 'caEC'
compared_states = ['R0', 'R1', 'R2']

sc.tl.rank_genes_groups(
    adata,
    groupby='aEC_states',
    groups=[reference_state] + compared_states,
    reference=reference_state,
    method='wilcoxon',
)

for state in compared_states:
    df = sc.get.rank_genes_groups_df(adata, group=[state])
    df.to_csv(f"{state}_vs_caEC.tsv", sep="\t", index=False)

def plot_MA(data, fc, exp, group, degs, adjust=True, **kwargs):
    """MA-style scatter of fold change against expression, with selected genes labelled.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per gene, indexed by gene name.
    fc : str
        Column plotted on the y axis (log2 fold change).
    exp : str
        Column plotted on the x axis (expression).
    group : str
        Column used as the seaborn ``hue``.
    degs : sequence of str
        Index labels of ``data`` to annotate. Labels missing from ``data`` are
        skipped with a warning instead of raising ``KeyError``.
    adjust : bool
        If True, move labels apart with ``adjust_text``.
    **kwargs
        Forwarded to ``seaborn.scatterplot`` (e.g. ``palette``). An ``ax`` keyword,
        if given, is used as the target axes instead of creating a new figure.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    import warnings

    ax = kwargs.pop('ax', None)
    if ax is None:
        _, ax = plt.subplots()
    sns.scatterplot(data=data, x=exp, y=fc, hue=group, linewidth=0, ax=ax, **kwargs)

    present = [gene for gene in degs if gene in data.index]
    missing = [gene for gene in degs if gene not in data.index]
    if missing:
        warnings.warn(f"plot_MA: genes not found in data, not labelled: {missing}")

    labeled = data.loc[present, [exp, fc]]
    # Iterate by position via zip. Integer-indexing a gene-indexed Series relies on
    # pandas' deprecated positional fallback.
    texts = [ax.text(x, y, gene)
             for gene, x, y in zip(labeled.index, labeled[exp], labeled[fc])]
    if adjust and texts:
        adjust_text(texts, labeled[exp], labeled[fc], ax=ax,
                    arrowprops=dict(arrowstyle='-', color='black'),
                    force_text=(0.2, 0.2), expand_text=(0.1, 0.1))
    ax.set_xlabel('scaled expression')
    ax.set_ylabel('Log2FC')
    return ax


from adjustText import adjust_text
import seaborn as sns
palette = 'coolwarm_r'
group = 'R1'
# Wilcoxon-score cutoff: genes above +cutoff are coloured as up in `group`,
# genes below -cutoff as up in the caEC reference.
SCORE_CUTOFF = 2.5

df = sc.get.rank_genes_groups_df(adata, group=[group])
# Rename by name rather than overwriting all columns positionally, so extra
# columns (if any) cannot shift the labels.
df = df.rename(columns={'logfoldchanges': 'lfc'})
df['group'] = "genome"
df.index = df['names']

# Per-gene mean of adata.X, computed without densifying the whole matrix
# (X.mean(axis=0) works for both scipy-sparse and dense X).
mean_exp = pd.Series(np.asarray(adata.X.mean(axis=0)).ravel(), index=adata.var_names)
df['exp'] = mean_exp  # aligned on gene names

df.loc[df['scores'] > SCORE_CUTOFF, "group"] = group
df.loc[df['scores'] < -SCORE_CUTOFF, "group"] = "caEC"

genestolabel = ["EFNB2", "DLL4", "NOTCH1", "NOTCH4", "HEY1", "HEY2"]
# fc/exp must name columns of df: 'lfc' (renamed above) and 'exp' (mean expression).
ax = plot_MA(df, fc="lfc", exp='exp', degs=genestolabel, group="group", palette=palette)
ax.set_ylim((-6, 6))
ax.set_xlim((0, 6));
