In [None]:
import anndata
import numpy as np
import seaborn as sb
import numpy as np
import warnings
import matplotlib.pyplot as plt
import pandas as pd
warnings.filterwarnings("ignore")
from genes2genes import Main
from genes2genes import ClusterUtils
from genes2genes import TimeSeriesPreprocessor
from genes2genes import PathwayAnalyser
from genes2genes import VisualUtils
import optbinning
import matplotlib.pyplot as plt
import pickle # save the data


#### 2. Load expression data and create AnnData objects

# Make sure that each adata object has:
# (1) log normalized gene expression in adata.X
# (2) pseudotime estimates in adata.obs['time'] 

In [None]:

# Reference and query AnnData objects (file names indicate control vs. knock-out, branch A).
adata_ref = anndata.read_h5ad("./1_con_BranchA.h5ad")
adata_query = anndata.read_h5ad("./1_ko_BranchA.h5ad")

# Expression matrices exported separately as CSV.
# NOTE(review): read_csv is called without index_col, so any identifier column in
# the CSV would be kept as data — confirm the files contain numeric values only.
mat_ref = pd.read_csv("./1_exp_con_BranchA.csv")
mat_query = pd.read_csv("./1_exp_ko_BranchA.csv")

# Replace .X with the transposed CSV contents.
# Presumably the CSVs are stored genes x cells and AnnData expects cells x genes — TODO confirm.
adata_ref.X = mat_ref.T
adata_query.X = mat_query.T

# Report the pseudotime range (min, max) of the reference, then of the query.
for pseudotime in (adata_ref.obs['time'], adata_query.obs['time']):
    print(min(pseudotime), max(pseudotime))


#### density plot for the pseudotime

In [None]:
# Create the figure and a single shared axis for both density curves.
fig, ax = plt.subplots()

# (dataset, legend label, fill colour) for each condition, drawn in this order.
density_specs = [
    (adata_ref, 'Control', 'forestgreen'),
    (adata_query, 'Vcl cKO', 'midnightblue'),
]
for dataset, condition, shade in density_specs:
    # legend=False here; the legend is built once from the labels further down.
    sb.kdeplot(
        dataset.obs['time'],
        fill=True,
        label=condition,
        color=shade,
        ax=ax,
        legend=False,
    )

ax.set_xlabel('Pseudotime', fontsize=16)
ax.set_ylabel('Density', fontsize=16)
ax.tick_params(axis='both', labelsize=14)
ax.legend(fontsize=12)
plt.show()

#### 3. determine the number of discrete time points to align

In [None]:
from optbinning import ContinuousOptimalBinning

# Fit a continuous optimal binning of pseudotime against itself for the reference,
# then for the query, and print how many split points each fit produced.
# After the loop, `x` and `optb` hold the query's values, as in the original cell.
for dataset in (adata_ref, adata_query):
    x = np.asarray(dataset.obs.time)
    optb = ContinuousOptimalBinning(name='time', dtype="numerical")
    optb.fit(x, x)
    print(len(optb.splits))


In [None]:
# Number of discrete pseudotime points used for the plots below and for the alignment.
n_bins = 10
# obs column holding the cell annotation used to colour the barplots.
sample = "final.annotation"

# List the annotation labels present in the query so the colormap keys can be
# checked against them. (`.unique` without parentheses only referenced the bound
# method and displayed nothing.)
print(adata_query.obs[sample].unique())

# Joint colormap shared by the reference and query plots (annotation label -> hex colour).
joint_cmap = {'BranchA': "#bca9f5", 'Neuroblast': "#fdcee6", "BP":"#6b853e"}

VisualUtils.plot_pseudotime_dists_with_interpolation_points(adata_ref, adata_query, n_bins)

# Cell-type composition per pseudotime bin, first with plot_cell_counts=True ...
VisualUtils.plot_celltype_barplot(adata_ref, n_bins, sample, joint_cmap, legend=True, plot_cell_counts = True)
VisualUtils.plot_celltype_barplot(adata_query, n_bins, sample, joint_cmap, legend=True, plot_cell_counts = True)

# ... then with the library default for plot_cell_counts.
VisualUtils.plot_celltype_barplot(adata_ref, n_bins, sample, joint_cmap, legend=True)
VisualUtils.plot_celltype_barplot(adata_query, n_bins, sample, joint_cmap, legend=True)

In [None]:
########### save bin assignment for each cell
time_colname = 'time'
annotation_colname = "final.annotation"

# The export below needs obs['bin_ids'] on both objects. That column is not
# created anywhere in this notebook; presumably VisualUtils.plot_celltype_barplot
# adds it as a side effect — TODO confirm. The original cell digitized the
# existing `bin_ids` column (not the pseudotime) and discarded the result, so it
# had no effect. Instead, compute the bins from pseudotime only when the column
# is missing, so this cell also works on a fresh kernel.
# NOTE(review): verify that this fallback binning matches the library's.
bin_edges = np.linspace(0, 1, num=n_bins)
for _adata in (adata_ref, adata_query):
    if 'bin_ids' not in _adata.obs.columns:
        # right=False: a value equal to an edge falls into the bin to its right,
        # so a cell at pseudotime 1.0 always ends up in its own last bin.
        _adata.obs['bin_ids'] = np.digitize(_adata.obs[time_colname], bin_edges, right=False)

# Per-cell table: bin id, pseudotime and annotation.
export_cols = ['bin_ids', time_colname, annotation_colname]
info_query = adata_query.obs[export_cols]
info_ref = adata_ref.obs[export_cols]

info_query.to_csv("mutant_time_bin.csv")
info_ref.to_csv("control_time_bin.csv")

#### alignment

In [None]:
# First column of the DEG table (read without a header row).
data = pd.read_csv("./DEGs.for.alignmnet.csv", header=None)
gene_list = data.iloc[:, 0]

# Genes present in both the reference and the query.
gene_ref = adata_ref.var_names.tolist()
gene_query = adata_query.var_names.tolist()
shared_genes = set(gene_ref) & set(gene_query)
genes = list(shared_genes)
# Alternative: additionally restrict to the DEG list.
# genes = list(set(gene_ref) & set(gene_query) & set(gene_list))
print(len(genes), 'genes')

In [None]:

# Build the reference-vs-query aligner over the selected genes, using n_bins
# discrete pseudotime points, then align every gene in the list.
aligner = Main.RefQueryAligner(adata_ref, adata_query, genes, n_bins) 
aligner.align_all_pairs() 

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_alignment_path_on_given_matrix(mat, paths, cmap='viridis', num = 100):
    """Draw a heatmap of ``mat`` and overlay alignment paths on it.

    Parameters
    ----------
    mat : array-like
        Matrix to display; converted with ``np.array`` and indexed as 2D.
    paths : iterable
        Each path is a sequence of points. Element 0 of a point is used as the
        row position and element 1 as the column position.
    cmap : str
        Colormap passed to the heatmap.
    num : number
        Cells whose value is strictly greater than ``num`` are labelled with
        their value, formatted without decimals.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    grid = np.array(mat)

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    sns.heatmap(grid, square=True, cmap=cmap, ax=ax, cbar=True, annot=False)

    # Collect the (row, col) cells exceeding the threshold, in row-major order.
    labelled_cells = [
        (r, c)
        for r in range(grid.shape[0])
        for c in range(grid.shape[1])
        if grid[r, c] > num
    ]
    for r, c in labelled_cells:
        # +0.5 places the text at the centre of the heatmap cell.
        ax.text(c + 0.5, r + 0.5, f'{grid[r, c]:.0f}',
                ha='center', va='center', color='black')

    # Overlay each path through the cell centres (x = column, y = row).
    for path in paths:
        row_centres = [step[0] + 0.5 for step in path]
        col_centres = [step[1] + 0.5 for step in path]
        ax.plot(col_centres, row_centres, color='white', linewidth=6)

    # Axis labelling: x axis label and ticks are moved to the top.
    ax.set_xlabel("Control", fontsize=16, fontweight='bold')
    ax.set_ylabel(r"$\bf{\it{Vcl}}$ $\bf{cKO}$", fontsize=16)
    ax.tick_params(axis='both', labelsize=14)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    plt.show()


# Aggregate over every gene held by the aligner (aligner.gene_list): the helper
# returns an average alignment and an alignment path.
average_alignment, alignment_path = ClusterUtils.get_cluster_average_alignments(aligner, aligner.gene_list)
# Pairwise match-count matrix over the same gene list; used as the heatmap background below.
mat = ClusterUtils.get_pairwise_match_count_mat(aligner, aligner.gene_list)

print('Average Alignment: ', VisualUtils.color_al_str(average_alignment), '(cell-level)')

# Store the result on the aligner object.
aligner.average_alignment = average_alignment
# num=1000: only heatmap cells with a value above 1000 are labelled with their value.
plot_alignment_path_on_given_matrix(paths=[alignment_path], mat=mat, num=1000)

In [None]:
##### gene patterns along trajectory

# Use a dedicated name here: rebinding `genes` would overwrite the shared gene
# list that is passed to Main.RefQueryAligner, so re-running the aligner cell
# afterwards would silently align only these three genes.
genes_to_plot = ["Egr1", "E2f1", "Klf7"]
for tf in genes_to_plot:
    # Plot the gene along both trajectories (plot_cells=True) ...
    VisualUtils.plotTimeSeries(tf, aligner, plot_cells=True)
    # ... then print its alignment string from the aligner's per-gene results.
    alignment_str = aligner.results_map[tf].alignment_str
    print(tf + ":" + VisualUtils.color_al_str(alignment_str))


In [None]:
##### overall alignment 
# Statistics table produced by the aligner (columns are defined by genes2genes).
df = aligner.get_stat_df()

#### alignment clustering

In [None]:
###### scan clustering parameters
# Run clustering with experiment_mode=True and keep whatever it returns in its
# own variable. The original `df.temp = ...` set an attribute on the stats
# DataFrame: that creates neither a column nor a separate variable, and pandas
# warns about it.
# NOTE(review): presumably the return value summarises scores per distance
# threshold — confirm against the genes2genes documentation.
clustering_scan = ClusterUtils.run_clustering(aligner, metric='levenshtein', experiment_mode=True)
# Selection rule used below: the threshold with the higher score but fewer clusters.

In [None]:
# Cluster the gene alignments with the Levenshtein metric at the chosen distance threshold (0.6).
ClusterUtils.run_clustering(aligner, metric='levenshtein', DIST_THRESHOLD=0.6) 
# Show the resulting clusters in a grid with 4 columns.
ClusterUtils.visualise_clusters(aligner, n_cols = 4, figsize= (10,6))

In [None]:
# Print the per-cluster average alignments via the library helper (output format is defined by genes2genes).
ClusterUtils.print_cluster_average_alignments(aligner)