### 0. Import libraries

In [None]:
import pandas as pd
import numpy as np
import polars as pl
import anndata as ad

import scanpy as sc

import matplotlib.pyplot as plt
import seaborn as sns

import scipy.stats as stats

import bbknn
# from numpy import cov
# from scipy.stats import pearsonr

#import os

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.rcParams["savefig.dpi"] = 300
plt.rcParams["figure.figsize"] = [6, 4.5]

### Load metadata

In [None]:
# Cell-level metadata (incl. lineage barcodes) for GSE150949, tab-separated
metadata_path = '/data/benchmarks/scRNAseq_persisters/GSE150949_metaData_with_lineage.txt'
df_metaData_with_lineage = pd.read_csv(metadata_path, sep="\t")
df_metaData_with_lineage

### Analyzing metadata

In [None]:
# Summary of lineage-barcode coverage: missing barcodes and cells carrying several
# (comma-separated) barcodes, as counts and as percentages of all cells.
lineage_barcode_col = df_metaData_with_lineage['lineage_barcode']
nr_cells_total = len(df_metaData_with_lineage)
nr_cells_no_barcode = int(lineage_barcode_col.isnull().sum())
nr_cells_multiple_barcodes = int(lineage_barcode_col.str.contains(',', na=False).sum())

print('The total number of cells =', nr_cells_total)
for description, count in (
    ('The number of cells without a lineage barcode =', nr_cells_no_barcode),
    ('The number of cells with multiple lineage barcodes =', nr_cells_multiple_barcodes),
):
    share_pct = round((count / nr_cells_total) * 100, 1)
    print(description, count, 'This is equal to ', share_pct, '%')

In [None]:
# QC sanity checks on the metadata: how many cells would fail each common filter?
qc_checks = [
    # mitochondrial fraction above 0.1
    ('The number of cells with >0.1 mitochondrial fraction is =', df_metaData_with_lineage['percent.mito'] > 0.1),
    # fewer than 1000 detected genes
    ('The number of cells with <1000 genes is =', df_metaData_with_lineage['nGene'] < 1000),
    # more than 4200 detected genes
    ('The number of cells with >4200 genes is =', df_metaData_with_lineage['nGene'] > 4200),
]
for message, failing_cells in qc_checks:
    print(message, int(failing_cells.sum()))

Since no cells have a mitochondrial fraction >0.1, fewer than 1000 genes, or more than 4200 genes, the data appear to have already been quality-filtered by Oren et al. (2021).

### Preprocessing metadata

In [None]:
copy_df = df_metaData_with_lineage.copy()  # work on a copy so the loaded metadata stays untouched

# Rename the day-14 sample_type labels (14_high -> Non-cycling, etc.) to avoid confusion.
# Only the sample_type column is rewritten: DataFrame.replace would substitute matching
# values in every column of the frame.
sample_type_renames = {
    '14_high': 'Non-cycling',
    '14_med': 'Moderate_cyclers',
    '14_low': 'Cycling',
}
copy_df['sample_type'] = copy_df['sample_type'].replace(sample_type_renames)

In [None]:
# Add a column for the fate of each cell's lineage at day 14.
# Created with object dtype (holding NaN) rather than float NaN: assigning string labels
# into a float64 column via .loc is an incompatible-dtype setitem, deprecated since
# pandas 2.1 and slated to raise in a future version.
copy_df['fate_day_14'] = pd.Series(np.nan, index=copy_df.index, dtype=object)

# Cells whose lineage_barcode lists several barcodes (comma-separated) get 'Multiple lineages'
multi_barcode_indices = copy_df['lineage_barcode'].str.contains(',', na=False)
copy_df.loc[multi_barcode_indices, 'fate_day_14'] = 'Multiple lineages'

# Index labels of the day-14 cells in each cycling group
index_non_cycling = copy_df.index[copy_df['sample_type'] == 'Non-cycling']
index_moderate_cyclers = copy_df.index[copy_df['sample_type'] == 'Moderate_cyclers']
index_cycling = copy_df.index[copy_df['sample_type'] == 'Cycling']

In [None]:
# Find barcodes of day 14 cells grouped per cell fate 
def get_unique_barcodes(df, indices_list):
    """Return the distinct lineage barcodes found among the selected cells.

    Parameters
    ----------
    df : pandas.DataFrame
        Cell metadata containing a 'lineage_barcode' column.
    indices_list : list or pandas.Index
        Row labels of the cells to inspect (here: day-14 cells sharing one fate).

    Returns
    -------
    pandas.Index
        Sorted, de-duplicated barcodes. Missing barcodes (NaN) are excluded; a
        multi-barcode string such as 'A,B' is kept as one entry.
    """
    selected = df.loc[indices_list, 'lineage_barcode']
    distinct = selected.dropna().unique()
    return pd.Index(sorted(distinct))

# Unique barcodes of the day-14 cells in each cycling group
barcodes_non_cycling = get_unique_barcodes(df=copy_df, indices_list=index_non_cycling)
barcodes_moderate_cyclers = get_unique_barcodes(df=copy_df, indices_list=index_moderate_cyclers)
barcodes_cycling = get_unique_barcodes(df=copy_df, indices_list=index_cycling)

In [None]:
# Barcodes shared by each pair of day-14 groups
common_noncycling_cycling = barcodes_non_cycling.intersection(barcodes_cycling)
common_noncycling_moderatecyclers = barcodes_non_cycling.intersection(barcodes_moderate_cyclers)
common_cycling_moderatecyclers = barcodes_cycling.intersection(barcodes_moderate_cyclers)

# A barcode seen in any two groups marks a multi-fate lineage
multifate_barcodes = (
    common_noncycling_cycling
    .union(common_noncycling_moderatecyclers)
    .union(common_cycling_moderatecyclers)
)

# Group-specific barcodes: everything in a group that is not multi-fate
# (removing all multi-fate barcodes equals removing the barcodes this group shares
# with the other two groups)
unique_barcodes_non_cycling = barcodes_non_cycling.difference(multifate_barcodes)
unique_barcodes_cycling = barcodes_cycling.difference(multifate_barcodes)
unique_barcodes_moderatecyclers = barcodes_moderate_cyclers.difference(multifate_barcodes)


In [None]:
# get indices of cells with lineage barcodes per group
all_non_cyclers_indices = copy_df['lineage_barcode'].isin(unique_barcodes_non_cycling)
all_moderatecyclers_indices = copy_df['lineage_barcode'].isin(unique_barcodes_moderatecyclers)
all_cyclers_indices = copy_df['lineage_barcode'].isin(unique_barcodes_cycling)
all_multifate_indices = copy_df['lineage_barcode'].isin(multifate_barcodes)

# enter fate in cell fate column
copy_df.loc[all_non_cyclers_indices, 'fate_day_14'] = 'Non-cycling'
copy_df.loc[all_moderatecyclers_indices, 'fate_day_14'] = 'Moderate_cyclers'
copy_df.loc[all_cyclers_indices, 'fate_day_14'] = 'Cycling'
copy_df.loc[all_multifate_indices, 'fate_day_14'] = 'Multi-fate'

copy_df


### 1. Load count matrix data & create an AnnData object

In [None]:
# Read the count matrix with polars (chosen over pandas for speed/memory on a file this size)
count_matrix_path = '/data/benchmarks/scRNAseq_persisters/GSE150949_pc9_count_matrix.csv'
df_pc9_count_matrix = pl.read_csv(count_matrix_path)

In [None]:
df_pc9_count_matrix.head(10)  # preview the first 10 rows; the first column holds the gene names

In [None]:
# Layout of the CSV: one row per gene; column 0 = gene names, remaining columns = cells.
gene_names = df_pc9_count_matrix[:, 0].to_list()
df_pc9_count_matrix_without_genenames = df_pc9_count_matrix[:, 1:]
cell_names = df_pc9_count_matrix_without_genenames.columns

# Dense numpy matrix (genes x cells) so it can be handed to AnnData
numpy_count_matrix = df_pc9_count_matrix_without_genenames.to_numpy()

# AnnData expects cells as observations and genes as variables, hence the transpose.
# NOTE(review): this is a dense matrix; a sparse representation would use far less memory.
adata = ad.AnnData(
    X=numpy_count_matrix.T,
    obs=pd.DataFrame(index=cell_names),
    var=pd.DataFrame(index=gene_names),
)

In [None]:
adata.X  # inspect the expression matrix (cells x genes, from the transposed count matrix)

In [None]:
adata  # summary of the AnnData object: n_obs (cells) x n_vars (genes)

So the dataset contains 56,419 cells and 22,166 genes.

### Enter relevant metadata to the AnnData object

In [None]:
# Enter relevant metadata into the AnnData object.
# Each Series below is assigned to adata.obs by index alignment. The metadata index must
# therefore contain the cell names used as adata.obs_names; otherwise the new columns would
# silently be filled with NaN. Fail loudly instead.
cells_without_metadata = ~adata.obs_names.isin(df_metaData_with_lineage.index)
assert not cells_without_metadata.any(), (
    f'{cells_without_metadata.sum()} cells in the count matrix have no metadata row '
    '(the metadata index must hold the cell names)'
)

# Lineage barcode of each cell (NaN when no barcode was detected)
adata.obs['lineage_barcode'] = df_metaData_with_lineage['lineage_barcode']

# Time point as a categorical (int64 -> category, for plotting later on)
adata.obs['time_point'] = df_metaData_with_lineage.time_point.astype('category')

# Day-14 fate of each cell's lineage, as a categorical
adata.obs['fate_day_14'] = copy_df.fate_day_14.astype('category')

# Sample type (time point, with day-14 cells split into cycling groups), as a categorical
adata.obs['sample_type'] = copy_df.sample_type.astype('category')

adata

In [None]:
# Drop genes with zero total counts: min_counts=1 keeps every gene with at least 1 count
sc.pp.filter_genes(adata, min_counts=1)
adata

No genes were removed by this filter, so there were no zero-count genes.

### 3. Normalization

In [None]:
# normalize_per_cell could not operate on the int64 count matrix, so cast to float64 first
adata.X = adata.X.astype(np.float64)

# Normalise each cell by its total UMI count; per-cell totals are stored in obs['n_counts_all']
sc.pp.normalize_per_cell(adata, key_n_counts='n_counts_all')

### 4. Identification of highly variable genes

Removing non-variable genes reduces the calculation time of the GRN reconstruction and simulation steps. It also improves the overall accuracy of GRN inference by removing noisy genes. Using the top 2,000–3,000 variable genes is recommended.

In [None]:
# Select the top 2000 highly variable genes with scanpy's legacy dispersion-based method.
# Given a plain matrix, it returns per-gene results (gene_subset, means, dispersions).
filter_result = sc.pp.filter_genes_dispersion(adata.X,
                                              flavor='cell_ranger',
                                              n_top_genes=2000,
                                              log=False)

# Subset the genes. Slicing an AnnData yields a view; .copy() makes it a real object so
# the in-place renormalisation below does not operate on (and implicitly copy) a view.
adata = adata[:, filter_result.gene_subset].copy()

# Renormalise after filtering so the total expression per cell is equal across the dataset
sc.pp.normalize_per_cell(adata)

In [None]:
# import seaborn as sns

# adata.var['variability_means'] = filter_result.means
# adata.var['variability_dispersions'] = filter_result.dispersions
# adata.var['variability_highly_variable'] = filter_result.gene_subset

# top10_genes = adata.var[adata.var['variability_highly_variable']].nlargest(10, 'variability_dispersions').index.tolist()


# # Plot all genes, highlighting highly variable genes
# plt.figure(figsize=(10, 6))
# sns.scatterplot(data=adata.var, x='variability_means', y='variability_dispersions', hue='variability_highly_variable', palette=['gray', 'red'], s=10)

# # Label the top 10 highly variable genes
# for gene in top10_genes:
#     plt.text(adata.var.loc[gene, 'variability_means'], adata.var.loc[gene, 'variability_dispersions'], gene, fontsize=8, color='black', ha='right')


In [None]:
adata.X  # inspect the expression matrix after HVG subsetting and renormalisation

### 5. Log transformation

In [None]:
# Keep the pre-log expression before log transformation.
# NOTE: despite the "raw_count" name, these are not the original UMI counts. At this point
# the data have been normalised per cell (twice) and restricted to the highly variable
# genes; adata.raw and the "raw_count" layer both hold these normalised values.
adata.raw = adata
adata.layers["raw_count"] = adata.raw.X.copy()

# Log transformation: natural log of (1 + X), so zeros stay finite (log1p(0) = 0)
sc.pp.log1p(adata)

# Keep the log-transformed data before scaling
adata.layers["log_transformed"] = adata.X.copy()

# Scaling: per-gene zero mean and unit variance
sc.pp.scale(adata)

adata

In [None]:
adata.obs  # per-cell metadata table

### 6. PCA and neighbor calculations

In [None]:
# PCA on the scaled expression (ARPACK solver)
sc.tl.pca(adata, svd_solver='arpack')

# Batch-balanced kNN graph (bbknn) on the first 20 PCs, treating each time point as a batch
# and taking 4 neighbours per batch. This replaces the uncorrected graph:
# sc.pp.neighbors(adata, n_neighbors=4, n_pcs=20)
bbknn.bbknn(adata, batch_key='time_point', neighbors_within_batch=4, n_pcs=20)

# Force-directed layout of the bbknn graph (fixed seed for a reproducible layout)
sc.tl.draw_graph(adata, random_state=123)
sc.pl.draw_graph(adata, color="sample_type")



In [None]:
adata  # check the embeddings and graphs added by the PCA / bbknn / draw_graph steps

In [None]:
# Diffusion map on the current (bbknn, PCA-based) neighbour graph, used to denoise the graph
sc.tl.diffmap(adata)

# Recompute the batch-balanced neighbours on the diffusion components instead of the PCs.
# Uncorrected alternative: sc.pp.neighbors(adata, n_neighbors=10, use_rep='X_diffmap')
bbknn.bbknn(adata, batch_key='time_point', neighbors_within_batch=10, use_rep='X_diffmap')

# Force-directed layout of the denoised graph (fixed seed for a reproducible layout)
sc.tl.draw_graph(adata, random_state=123)
sc.pl.draw_graph(adata, color="sample_type")

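The diffusion map is used to denoise the graph: the neighbour graph is rebuilt on the diffusion components (`X_diffmap`) instead of on the principal components.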

In [None]:
# Fraction of variance explained by each principal component
variance_ratio = adata.uns['pca']['variance_ratio']
pc_numbers = np.arange(1, len(variance_ratio) + 1)  # 1-based PC labels for the x-axis

# Elbow plot to choose how many PCs to keep
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(pc_numbers, variance_ratio, marker='o')
ax.set_xlabel('Principal Component')
ax.set_ylabel('Variance Explained')
ax.set_title('Elbow Plot for PCA')
ax.grid(True)
# fig.savefig('/home/jolien/Notebooks/data_preprocessing/figures/PCA_elbow_plot.png')
plt.show()

In [None]:
# Copy the first six principal components into adata.obs as columns PC1..PC6
pca_scores = adata.obsm['X_pca']
for pc_idx in range(6):
    adata.obs[f'PC{pc_idx + 1}'] = pca_scores[:, pc_idx]

In [None]:
# Explained variance (%) of PC1 and PC2.
# variance_ratio is 0-indexed: PC k sits at position k-1, so PC2 is [1] ([2] would be PC3).
pca_variance_ratio = adata.uns['pca']['variance_ratio']
expl_var_pc1 = pca_variance_ratio[0] * 100
expl_var_pc2 = pca_variance_ratio[1] * 100

# Plot PC1 vs PC2 for all cells
plt.figure(figsize=(8, 6))
plt.scatter(adata.obs['PC1'], adata.obs['PC2'], s=5)
plt.xlabel('PC1 ({:.1f}%)'.format(expl_var_pc1))
plt.ylabel('PC2 ({:.1f}%)'.format(expl_var_pc2))
plt.title('PCA Plot of PC1 vs PC2')
plt.show()

In [None]:
# PC1 vs PC2, one colour per sample type.
# Group masks come from adata.obs (aligned to the cells by name when the metadata was added)
# rather than copy_df, so they cannot silently mismatch the row order of adata.obsm['X_pca'].
sample_type_styles = [
    # (sample_type value, colour, legend label)
    ('0', 'g', 'Day 0'),
    ('3', 'r', 'Day 3'),
    ('7', 'k', 'Day 7'),
    ('Cycling', 'b', 'Day 14 - cycling'),                    # formerly 14_low
    ('Moderate_cyclers', 'c', 'Day 14 - moderate cyclers'),  # formerly 14_med
    ('Non-cycling', 'm', 'Day 14 - non-cycling'),            # formerly 14_high
]
pca_coords = adata.obsm['X_pca']
sample_types = adata.obs['sample_type'].to_numpy()

plt.figure(figsize=(8, 6))
for sample_type, colour, label in sample_type_styles:
    in_group = sample_types == sample_type
    plt.scatter(pca_coords[in_group, 0], pca_coords[in_group, 1], c=colour, s=5, label=label)

plt.xlabel('PC1 ({:.1f}%)'.format(expl_var_pc1))
plt.ylabel('PC2 ({:.1f}%)'.format(expl_var_pc2))
plt.title('PCA Plot of PC1 vs PC2 colored by sample type')
plt.legend()
# plt.savefig('/home/jolien/Notebooks/data_preprocessing/figures/PCA_colored_by_sample_type.png')
plt.show()

In [None]:
plt.figure(figsize=(8, 6))

# Marker size is set with `s` (passed through to matplotlib). seaborn's `size` argument is a
# semantic data variable, so size=5 is treated as a constant size mapping rather than a
# marker size.
sns.scatterplot(data=adata.obs, x="PC1", y="PC2", hue="sample_type", s=5)

plt.xlabel('PC1 ({:.1f}%)'.format(expl_var_pc1))
plt.ylabel('PC2 ({:.1f}%)'.format(expl_var_pc2))
plt.title('PCA Plot of PC1 vs PC2 colored by sample type')

# plt.savefig('/home/jolien/Notebooks/data_preprocessing/figures/PCA_colored_by_sample_type_v2.png')

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# Density contours of PC1/PC2, one set of contour lines per sample type
sns.kdeplot(data=adata.obs, x="PC1", y="PC2", hue="sample_type", ax=ax)

ax.set_xlabel(f'PC1 ({expl_var_pc1:.1f}%)')
ax.set_ylabel(f'PC2 ({expl_var_pc2:.1f}%)')
ax.set_title('PCA density contour plot for sample types')

# fig.savefig('/home/jolien/Notebooks/data_preprocessing/figures/PCA_contour_colored_by_sample_type.png')

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# Filled, semi-transparent density contours of PC1/PC2 per sample type
sns.kdeplot(data=adata.obs, x="PC1", y="PC2", hue="sample_type", alpha=0.3, fill=True, ax=ax)

ax.set_xlabel(f'PC1 ({expl_var_pc1:.1f}%)')
ax.set_ylabel(f'PC2 ({expl_var_pc2:.1f}%)')
ax.set_title('PCA density contour plot for sample types')

# fig.savefig('/home/jolien/Notebooks/data_preprocessing/figures/PCA_contour_colored_by_sample_type_filled.png')

Check the correlation between the principal components and time point

In [None]:
fig, axes = plt.subplots(3, 2, figsize=(20, 18), sharey=True)

# One boxplot panel per principal component (PC1..PC6), filled row by row
for pc_nr, ax in enumerate(axes.flat, start=1):
    pc_col = f'PC{pc_nr}'
    sns.boxplot(x='sample_type', y=pc_col, data=adata.obs, ax=ax)
    ax.set_title(f'{pc_col} grouped per sample type')
    ax.set_ylabel(pc_col)

# fig.savefig('/home/jolien/Notebooks/data_preprocessing/figures/boxplots_PCs_per_sample_type.png')

In [None]:
# Does PC1 differ between day-0 and day-3 cells? Rank-based test (no normality assumption);
# with only two groups, Kruskal-Wallis is closely related to the Mann-Whitney U test.
pc1_by_group = [
    adata.obs.loc[adata.obs['sample_type'] == group, 'PC1']
    for group in ('0', '3')
]
kruskal_result = stats.kruskal(*pc1_by_group)

print("Kruskal-Wallis p-value:", kruskal_result.pvalue)


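**TODO:** It is unclear whether this result can be used, because the p-value is not shown with more digits. If the printed p-value is 0.0, it has underflowed below the smallest representable floating-point number. In that case it can only be reported as an upper bound (e.g. p < 1e-300), not as an exact value.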

In [None]:
# Covariance and Pearson correlation between each of the first six PCs and time point.
# time_point is stored as a categorical in adata.obs; cast it to float so numpy/scipy
# receive plain numeric values.
time_numeric = adata.obs['time_point'].astype(float)

for pc_nr in range(1, 7):
    pc_name = f'PC{pc_nr}'
    pc_values = adata.obs[pc_name]

    # 2x2 covariance matrix; the off-diagonal entries are cov(time point, PC)
    cov_matrix = np.cov(time_numeric, pc_values)
    print(f'covariance {pc_name} and time point:')
    print(cov_matrix)

    corr, _ = stats.pearsonr(time_numeric, pc_values)
    print(f'Pearsons correlation {pc_name} and time point: %.3f' % corr, '\n')

PC1 and time point have a positive covariance (the off-diagonal values in the covariance matrix) and a Pearson correlation of 0.6. This indicates a dependency between PC1 and time point: a moderate positive relationship.

### 8. Dimensionality reduction using PAGA and force-directed graphs as well as UMAP

In [None]:
adata  # state of the AnnData object before clustering / PAGA

In [None]:
# Louvain community detection on the current neighbour graph; labels go to adata.obs['louvain']
sc.tl.louvain(adata, key_added='louvain')

In [None]:
# PAGA: coarse-grained connectivity graph between the Louvain clusters
sc.tl.paga(adata, groups='louvain')

In [None]:
# Reset the default figure size, then draw the PAGA cluster graph
plt.rcParams.update({"figure.figsize": [6, 4.5]})
sc.pl.paga(adata)

In [None]:
# Force-directed layout initialised from the PAGA graph positions;
# the fixed seed keeps the layout identical across runs
sc.tl.draw_graph(adata, random_state=123, init_pos='paga')

In [None]:
# UMAP embedding of the current neighbour graph (fixed seed for reproducibility)
sc.tl.umap(adata,random_state=123)

### 9. Visualization

In [None]:
# Plot the PAGA-initialised force-directed graph coloured by cluster, time point and sample type.
# `save=` writes the figure under scanpy's figure directory (sc.settings.figdir) with this suffix.
sc.pl.draw_graph(adata, color=["louvain", "time_point","sample_type"], legend_loc='on data', save="_PAGA_batch_correction_all_groupings.png")

In [None]:
# UMAP coloured by cluster, time point and sample type (saved under sc.settings.figdir)
sc.pl.umap(adata, color=['louvain','time_point','sample_type'],save="_UMAP_batch_correction_all_groupings.png")

In [None]:
# UMAP coloured by the day-14 fate of each cell's lineage (derived from the lineage barcodes)
sc.pl.umap(adata, color='fate_day_14',save="_UMAP_batch_correction_fate_day14.png")

In [None]:
# Colour palettes: sample_type keys are the (renamed) sample_type labels;
# time-point keys are the integer time_point categories.
sample_type_palette = {
    '0': '#F560A6',  # Pink
    '3': '#91307F',  # Purple
    '7': '#2D0059',  # Dark purple
    'Cycling': '#1f77b4',  # Blue
    'Moderate_cyclers': '#ff7f0e',  # Orange
    'Non-cycling': '#2ca02c',  # Green
}

sample_type_palette_time = {
    0: '#F560A6',  # Pink
    3: '#91307F',  # Purple
    7: '#2D0059',  # Dark purple
    14: '#5b5b5b',  # Grey
}

# Force-directed graph (PAGA-initialised) in three panels with legends next to the plots
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
sc.pl.draw_graph(adata, color='louvain', legend_loc='on data', ax=axes[0], show=False)   # cluster labels on the data
sc.pl.draw_graph(adata, color='time_point', ax=axes[1], palette=sample_type_palette_time, show=False)
sc.pl.draw_graph(adata, color='sample_type', ax=axes[2], palette=sample_type_palette, show=False)

# Tighten the spacing between panels. This cell only displays the figure; it does not write it to disk.
plt.tight_layout()

### 11. Save AnnData object

In [None]:
# Save the preprocessed AnnData object to .h5ad.
# NOTE(review): absolute, user-specific path; consider a configurable output directory.
adata.write('/home/jolien/Notebooks/data/preprocessed_data_bbknn_batchcorrection.h5ad')

### 12. Load AnnData object

In [None]:
# Read the preprocessed AnnData object written above (lets this section run without redoing preprocessing)
adata_preprocessed = sc.read_h5ad('/home/jolien/Notebooks/data/preprocessed_data_bbknn_batchcorrection.h5ad')

In [None]:
# Colour palettes, redefined so this section also works after only loading the .h5ad file
sample_type_palette = {
    '0': '#F560A6',  # Pink
    '3': '#91307F',  # Purple
    '7': '#2D0059',  # Dark purple
    'Cycling': '#1f77b4',  # Blue
    'Moderate_cyclers': '#ff7f0e',  # Orange
    'Non-cycling': '#2ca02c',  # Green
}

sample_type_palette_time = {
    0: '#F560A6',  # Pink
    3: '#91307F',  # Purple
    7: '#2D0059',  # Dark purple
    14: '#5b5b5b',  # Grey
}

# Three panels of the PAGA-initialised force-directed graph:
# Louvain clusters (labels on the data), time point, and sample type
panel_specs = [
    {'color': 'louvain', 'legend_loc': 'on data'},
    {'color': 'time_point', 'palette': sample_type_palette_time},
    {'color': 'sample_type', 'palette': sample_type_palette},
]
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
for ax, spec in zip(axes, panel_specs):
    sc.pl.draw_graph(adata_preprocessed, ax=ax, show=False, **spec)

# Lay out the panels and save the combined figure
plt.tight_layout()
plt.savefig("/home/jolien/Notebooks/data_preprocessing/batch_correction/bbknn/figures/PAGA_batch_correction_all_groupings.png")