In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import warnings
import anndata
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
from scipy.stats import spearmanr
import seaborn as sns
import decoupler as dc
warnings.filterwarnings("ignore")
import matplotlib.colors as mcolors

pd.set_option('display.max_columns', 500)
from pysankey2 import Sankey
from sklearn.preprocessing import MinMaxScaler
# R interface
from rpy2.robjects import pandas2ri
from rpy2.robjects import r
import rpy2.rinterface_lib.callbacks
import anndata2ri

plt.rcParams.update({
    'font.family': 'Arial'
})

pandas2ri.activate()
anndata2ri.activate()

%load_ext rpy2.ipython

In [None]:
def _print_value_ranges(label, adata):
    """Print ``label`` followed by the min/max of ``adata.X`` and of the
    'log_transformed' and 'counts' layers (same output as the original prints)."""
    print(label)
    print('X: ', adata.X.min(), adata.X.max())
    for layer_name in ('log_transformed', 'counts'):
        layer = adata.layers[layer_name]
        print(f'{layer_name}: ', layer.min(), layer.max())


# Load the three datasets and report the value range of each matrix, to see at a
# glance which representation X and each layer hold.
adata16 = sc.read('Data/adata_d16_annotated.h5ad')
_print_value_ranges('16', adata16)

adata25 = sc.read('Data/adata_d25_annotated.h5ad')
_print_value_ranges('25', adata25)

adata_d50_d70 = sc.read('Data/adata_d50_d70_neurons.h5ad')
_print_value_ranges('50 70', adata_d50_d70)


In [None]:

# Day-16 UMAPs: the annotated cell types, then a panel of selected genes.
_umap_kwargs = dict(ncols=4, use_raw=False, cmap='jet', frameon=False, size=18, layer='log_transformed')
with plt.rc_context({"figure.dpi": 300, "figure.figsize": (4, 4)}):
    for _colors in (['Cell_types'], ['POMC', 'OTP', 'DLX6-AS1', 'STMN2']):
        sc.pl.umap(adata16, color=_colors, **_umap_kwargs)

In [None]:
# Override the annotation of clusters '4' and '12' (integrated_snn_res.0.85) with "Cycling".
adata25.obs['Cell_types'] = adata25.obs['Cell_types'].astype(str)
is_cycling = adata25.obs['integrated_snn_res.0.85'].isin(['4', '12'])
# Write through obs.loc[rows, column]. The chained form obs['Cell_types'].loc[...] = ...
# assigns into an intermediate Series and does not update obs under pandas Copy-on-Write.
adata25.obs.loc[is_cycling, 'Cell_types'] = "Cycling"

with plt.rc_context({"figure.dpi": 300, "figure.figsize": (4, 4)}):
    sc.pl.umap(adata25, color=['Cell_types'], ncols=4, use_raw=False, cmap='jet', frameon=False, size=18, layer='log_transformed')
    sc.pl.umap(adata25, color=['STMN2', 'DIO2'], ncols=4, use_raw=False, cmap='jet', frameon=False, size=18, layer='log_transformed')

In [None]:
# Get the cells of interest
adata16 = adata16[adata16.obs.Cell_types.isin(['OTP+ neurons', 'POMC+ neurons','DLX6-AS1+ neurons'])]
adata16.obs['reactionID'] =  adata16.obs['diff_batch_2'].astype(str) + '|D16' 
adata16.obs['Cell_types'] = adata16.obs['Cell_types'].replace({'DLX6-AS1+ neurons': 'DLX6-AS1+','OTP+ neurons': 'OTP+', 'POMC+ neurons': 'POMC+'})

adata25 = adata25[adata25.obs.Cell_types.isin(['DLX6-AS1+ neurons','NR5A2/ONECUT1/3+ neurons', 'OTP+ neurons','POMC+ neurons'])]
adata25.obs['reactionID'] =  adata25.obs['reactionID'].astype(str) + '|D25' 
adata25.obs['Cell_types'] = adata25.obs['Cell_types'].replace({'DLX6-AS1+ neurons': 'DLX6-AS1+','OTP+ neurons': 'OTP+', 'POMC+ neurons': 'POMC+', 'NR5A2/ONECUT1/3+ neurons': 'NR5A2/ONECUT1/3+'})


adata_d50_d70.obs['reactionID'] =  adata_d50_d70.obs['reactionID'].astype(str) + '|D50_70' 


In [None]:
# Cell-type UMAP of each retained dataset, one after another.
_umap_kwargs = dict(ncols=4, use_raw=False, cmap='jet', frameon=False, size=18,
                    layer='log_transformed', wspace=0.5)
with plt.rc_context({"figure.dpi": 300, "figure.figsize": (4, 4)}):
    for _adata in (adata16, adata25, adata_d50_d70):
        sc.pl.umap(_adata, color=['Cell_types'], **_umap_kwargs)
    


In [None]:
# Merge the three datasets on their shared genes (join='inner').
# NOTE(review): AnnData.concatenate is deprecated in recent anndata in favour of
# anndata.concat. It is kept here because switching changes how obs_names are suffixed.
adata_concat = adata16.concatenate([adata25, adata_d50_d70], batch_key=None, join='inner')
# Keep columns present in all datasets by dropping every obs column that contains any
# missing value (columns absent from some inputs are NaN-padded there). This also drops
# a shared column that has genuine NaNs.
adata_concat.obs = adata_concat.obs.dropna(axis=1, how='any')

sc.pp.highly_variable_genes(adata_concat, n_top_genes=1500, inplace=True, batch_key='reactionID')
# Names of the flagged genes, in var order; passed to MetaNeighbor as var_genes.
features = adata_concat.var_names[adata_concat.var['highly_variable'].to_numpy()].tolist()

In [None]:
%%R -i adata_concat -i features -o auroc -o auroc_col -o auroc_row
# Score cell-type replicability across reactions with unsupervised MetaNeighbor.
# In:  adata_concat (converted by the activated anndata2ri converter) and features
#      (highly variable gene names).
# Out: the AUROC matrix plus its column and row names, for rebuilding a DataFrame in Python.
# NOTE(review): presumably a workaround for a Matrix / Seurat version mismatch
# ("Csparse_validate" lookup); this variable is not referenced below. Confirm it is still needed.
Csparse_validate = "CsparseMatrix_validate"

library(MetaNeighbor)
library(SummarizedExperiment)
library(Seurat)

# Build a Seurat object using the "counts" assay as counts (no separate data slot).
sobj <- as.Seurat(adata_concat, counts = "counts", data = NULL)

# Back to SingleCellExperiment, the container MetaNeighborUS is called with here.
sce_data = as.SingleCellExperiment(sobj)

# Unsupervised MetaNeighbor on the 'counts' assay restricted to `features`.
# study_id comes from obs$reactionID and cell_type from obs$Cell_types.
auroc = MetaNeighborUS(var_genes = features, dat = sce_data, i = 'counts',fast_version=T,
                      study_id=sce_data$reactionID, cell_type = sce_data$Cell_types)

# Export the dimnames separately so Python can label the matrix.
auroc_col = colnames(auroc)
auroc_row = rownames(auroc)

In [None]:
auroc_df = pd.DataFrame(auroc, index=auroc_row, columns=auroc_col)

# Group key for each label: everything after the first '.'. Labels that share that
# remainder are averaged together below (labels without a '.' get NaN and are dropped).
_label_pattern = r'^[^.]*\.(.*)'
group_rows = pd.Series(auroc_row).str.extract(_label_pattern)[0].values
# Column groups must come from the column labels (this previously reused auroc_row).
group_cols = pd.Series(auroc_col).str.extract(_label_pattern)[0].values

# Mean over rows sharing a group, then over columns sharing a group. The column step
# goes through a transpose because groupby(axis=1) is deprecated in pandas >= 2.1.
mean_auroc_df = (
    auroc_df.groupby(group_rows).mean()
    .T.groupby(group_cols).mean()
    .T
)

# Use '|' instead of '.' as the separator in the displayed labels.
mean_auroc_df.columns = [label.replace('.', '|') for label in mean_auroc_df.columns]
mean_auroc_df.index = [label.replace('.', '|') for label in mean_auroc_df.index]

mean_auroc_df

In [None]:
labels = mean_auroc_df.columns

with plt.rc_context({"figure.dpi": 300}):
    fig, ax = plt.subplots()
    # Pass the colormap explicitly instead of setting rcParams["image.cmap"].
    im = ax.imshow(mean_auroc_df, cmap="coolwarm")

    # Label each axis from its own labels. Rows and columns are not guaranteed to share
    # an order (a later cell reorders only the columns), so y must come from the index.
    ax.set_xticks(np.arange(len(labels)), labels=labels, size=6)
    ax.set_yticks(np.arange(len(mean_auroc_df.index)), labels=mean_auroc_df.index, size=6)

    # Rotate the x tick labels and anchor them at their right end.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Without a colorbar the AUROC values behind the colours cannot be read off.
    fig.colorbar(im, ax=ax, label="Mean AUROC")
    fig.tight_layout()
    plt.show()

In [None]:
# Formulate the data for sankey plot
d16_cell_types = [cell_type for cell_type in mean_auroc_df.columns if 'D16' in cell_type]
d25_cell_types = [cell_type for cell_type in mean_auroc_df.columns if 'D25' in cell_type]
d50_70_cell_types = [cell_type for cell_type in mean_auroc_df.columns if 'D50_70' in cell_type]

d16_vs_d25 = mean_auroc_df[d16_cell_types].T[d25_cell_types].applymap(lambda x: 0 if x < 0.7 else x)
d25_vs_d50_70 = mean_auroc_df[d25_cell_types].T[d50_70_cell_types].applymap(lambda x: 0 if x < 0.7 else x)
d25_vs_d50_70 = mean_auroc_df[d25_cell_types].T[d50_70_cell_types].applymap(lambda x: 0 if x < 0.7 else x)



In [None]:
# Thresholded D16 -> D25 AUROC table (rows: D16 cell types, columns: D25 cell types).
d16_vs_d25

In [None]:
# NOTE(review): repeats the d16_vs_d25 display of the previous cell; likely a leftover.
d16_vs_d25

In [None]:
# Thresholded D25 -> D50_70 AUROC table (rows: D25 cell types, columns: D50_70 cell types).
# This cell used to also reorder mean_auroc_df / d25_vs_d50_70 using labels such as
# 'D16|DLX6-AS1+ neurons'. Those labels no longer exist once the " neurons" suffix is
# stripped during subsetting, so the cell raised KeyError on Run All. The reordering is
# done with the current labels in the next cell.
d25_vs_d50_70


In [None]:
# Reorder the columns
mean_auroc_df = mean_auroc_df[['D16|DLX6-AS1+', 'D16|OTP+', 'D16|POMC+','D25|DLX6-AS1+', 'D25|NR5A2/ONECUT1/3+',
       'D25|OTP+', 'D25|POMC+', 'D50_70|Unassigned', 'D50_70|DLX6-AS1+/FOXP2+','D50_70|GHRH+/PNOC+','D50_70|PNOC+/NPFFR2+','D50_70|NR5A2+/ONECUT1/3+','D50_70|AGRP+/OTP+','D50_70|PCSK1+/ADGRL4+','D50_70|UNC13C+/OTP+','D50_70|POMC+/TBX3+/NR5A2+','D50_70|CRABP1+/TRH+']]

d25_vs_d50_70 = d25_vs_d50_70[['D50_70|Unassigned', 'D50_70|DLX6-AS1+/FOXP2+','D50_70|GHRH+/PNOC+','D50_70|PNOC+/NPFFR2+','D50_70|NR5A2+/ONECUT1/3+','D50_70|AGRP+/OTP+','D50_70|PCSK1+/ADGRL4+','D50_70|UNC13C+/OTP+','D50_70|POMC+/TBX3+/NR5A2+','D50_70|CRABP1+/TRH+']]

# Reformat data for sankey
celltype_index_dict =  dict(zip(mean_auroc_df.columns, list(range(len(mean_auroc_df.columns)))))

d16_vs_d25.index = [str(celltype_index_dict.get(item, item)) for item in d16_vs_d25.index]
d16_vs_d25.columns = [str(celltype_index_dict.get(item, item)) for item in d16_vs_d25.columns]
d25_vs_d50_70.index = [str(celltype_index_dict.get(item, item)) for item in d25_vs_d50_70.index]
d25_vs_d50_70.columns = [str(celltype_index_dict.get(item, item)) for item in d25_vs_d50_70.columns]

d16_vs_d25.index.name = "source"
d16_vs_d25.columns.name = "target"

d25_vs_d50_70.index.name = "source"
d25_vs_d50_70.columns.name = "target"

df = pd.concat([d16_vs_d25.stack().reset_index(name="value"),d25_vs_d50_70.stack().reset_index(name="value")]).reset_index(drop=True)

df = df[df.value>0]
df['boolean_value'] = 1

# Generate colors for correlation
norm = plt.Normalize(vmin=0.70, vmax=1.0)
cmap = plt.cm.coolwarm
df['color'] = [mcolors.to_hex(cmap(norm(value))) for value in df.value]

df

In [None]:
import plotly.graph_objects as go

# Sankey of cell-type correspondences across time points, node labels hidden.
# Positions are hand-tuned and follow the order of celltype_index_dict:
# x = 0.1 for the D16 nodes, 0.25 for D25, 0.65 for D50_70.
# NOTE(review): 19 positions are listed; confirm this matches the number of nodes.
x = [0.1] * 3 + [0.25] * 4 + [0.65] * 12
y = [0, 0.1, 0.2, 0, 0.05, 0.1, 0.15, 0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]

# Move exact 0 / 1 coordinates to 0.001 / 0.999 so every position is strictly inside (0, 1).
x = [.001 if v == 0 else .999 if v == 1 else v for v in x]
y = [.001 if v == 0 else .999 if v == 1 else v for v in y]

link = dict(
    source=df.source.values.tolist(),
    target=df.target.values.tolist(),
    value=df.boolean_value.values.tolist(),
    color=df.color,
)
node = dict(
    label=[''] * len(celltype_index_dict),
    pad=25,
    thickness=1.2,
    line=dict(color="black", width=2),
    x=x,
    y=y,
)

# No matplotlib rc_context here: matplotlib settings have no effect on a Plotly figure.
chart = go.Sankey(link=link, node=node, arrangement="snap")
fig = go.Figure(chart)
fig.update_layout(font=dict(size=24, weight=600, color='black', shadow='black'))
fig.show()

In [None]:
from IPython.display import Image

# Save the Sankey figure currently bound to `fig` as a PDF.
fig.write_image(
    "d16_25_50_cluster_trajectory1.pdf",
    format="pdf",
    width=1300,
    height=500,
    scale=10,
)


In [None]:
import plotly.graph_objects as go

# Same Sankey with the D25 nodes spread further apart vertically; node labels hidden.
# Positions follow the order of celltype_index_dict:
# x = 0.1 for the D16 nodes, 0.25 for D25, 0.65 for D50_70.
# NOTE(review): 19 positions are listed; confirm this matches the number of nodes.
x = [0.1] * 3 + [0.25] * 4 + [0.65] * 12
y = [0, 0.1, 0.2, 0, 0.35, 0.62, 0.85, 0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]

# Move exact 0 / 1 coordinates to 0.001 / 0.999 so every position is strictly inside (0, 1).
x = [.001 if v == 0 else .999 if v == 1 else v for v in x]
y = [.001 if v == 0 else .999 if v == 1 else v for v in y]

link = dict(
    source=df.source.values.tolist(),
    target=df.target.values.tolist(),
    value=df.boolean_value.values.tolist(),
    color=df.color,
)
node = dict(
    label=[''] * len(celltype_index_dict),
    pad=25,
    thickness=1.2,
    line=dict(color="black", width=2),
    x=x,
    y=y,
)

# No matplotlib rc_context here: matplotlib settings have no effect on a Plotly figure.
chart = go.Sankey(link=link, node=node, arrangement="snap")
fig = go.Figure(chart)
fig.update_layout(font=dict(size=24, weight=600, color='black', shadow='black'))
fig.show()

In [None]:
from IPython.display import Image

# Save the Sankey figure currently bound to `fig` as a PDF.
fig.write_image(
    "d16_25_50_cluster_trajectory.pdf",
    format="pdf",
    width=1300,
    height=500,
    scale=10,
)


In [None]:
import plotly.graph_objects as go

# Labelled Sankey: each node shows its cell-type name without the time-point prefix.
# Positions follow the order of celltype_index_dict:
# x = 0.1 for the D16 nodes, 0.25 for D25, 0.65 for D50_70.
# NOTE(review): 19 positions are listed; confirm this matches the number of nodes.
x = [0.1] * 3 + [0.25] * 4 + [0.65] * 12
y = [0, 0.1, 0.2, 0, 0.05, 0.1, 0.15, 0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]

# Move exact 0 / 1 coordinates to 0.001 / 0.999 so every position is strictly inside (0, 1).
x = [.001 if v == 0 else .999 if v == 1 else v for v in x]
y = [.001 if v == 0 else .999 if v == 1 else v for v in y]

link = dict(
    source=df.source.values.tolist(),
    target=df.target.values.tolist(),
    value=df.boolean_value.values.tolist(),
    color=df.color,
)
node = dict(
    # 'D25|OTP+' -> 'OTP+': keep only the part after the first '|'.
    label=[label.split('|', 1)[1] for label in celltype_index_dict],
    pad=25,
    thickness=1.2,
    line=dict(color="black", width=2),
    x=x,
    y=y,
)

# No matplotlib rc_context here: matplotlib settings have no effect on a Plotly figure.
chart = go.Sankey(link=link, node=node, arrangement="snap")
fig = go.Figure(chart)
fig.update_layout(font=dict(size=24, weight=600, color='black', shadow='black'))
fig.show()

In [None]:
from IPython.display import Image

# Save the labelled Sankey under its own name. This cell previously reused
# "d16_25_50_cluster_trajectory.pdf" and overwrote the unlabelled export written earlier.
fig.write_image("d16_25_50_cluster_trajectory_labeled.pdf", format="pdf", width=1300, height=500, scale=10)
