# 0. Import libraries

In [None]:
import pandas as pd
import numpy as np
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import polars as pl

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import cross_validate
from sklearn.model_selection import StratifiedKFold
import seaborn as sns

# Functions

In [None]:
def plot_logFC_ranked_genes(list_dfs, labels, fig_rows, fig_columns, figsize, plot_title):
    """
    Plot the logfoldchanges of the ranked DEGs per group, one bar-plot panel per group.

    Input:
    - list_dfs: list of the dataframes containing the results of the ranked genes (columns 'names' and
      'logfoldchanges'), one dataframe per group. list_dfs[i] is drawn in panel i and titled with the i-th
      category of `labels`, so the list must follow the order of labels.dtype.categories.
    - labels: categorical series with the label of each cell; its categories give the panel titles
    - fig_rows: number of rows for subplots
    - fig_columns: number of columns for subplots (fig_rows * fig_columns must be at least the number of groups)
    - figsize: size of the figure
    - plot_title: title of the whole figure
    Output: plot of the logFC of the ranked genes per group (nothing is returned)
    """
    fig, axes = plt.subplots(fig_rows, fig_columns, figsize=figsize)
    # plt.subplots returns a single Axes for a 1x1 grid and a 2-D array when both fig_rows and
    # fig_columns are > 1; flatten to 1-D so axes[i] works for every grid shape.
    axes = np.atleast_1d(axes).ravel()

    for i, group in enumerate(labels.dtype.categories):
        ax = axes[i]

        # Sort genes on logfoldchange (ascending) so the bars form an increasing profile
        top_genes_sorted = list_dfs[i].sort_values('logfoldchanges')

        ax.bar(top_genes_sorted['names'], top_genes_sorted['logfoldchanges'])

        ax.set_xlabel("Ranked gene name")
        ax.set_ylabel("Log Fold Change")
        ax.set_title("{} DEGs".format(group))
        # Rotate only the x-axis (gene name) tick labels; without axis='x' the y-axis values are rotated too
        ax.tick_params(axis='x', rotation=90)

    fig.suptitle(plot_title)

In [None]:
def manually_filter_genes(df_ranked_genes, min_fold_change, max_adj_pvalue):
    """
    Manually filter the DEGs: keep only genes with a large enough absolute logfoldchange
    and a small enough adjusted p-value.

    Input:
    - df_ranked_genes: dataframe with the results of sc.tl.rank_genes_groups() obtained using
      sc.get.rank_genes_groups_df(); must contain the columns 'logfoldchanges' and 'pvals_adj'.
    - min_fold_change: a gene is kept only if |logfoldchange| is strictly greater than this value
    - max_adj_pvalue: a gene is kept only if its adjusted p-value is strictly smaller than this value
      (i.e. the maximum false discovery rate)
    Output:
    - df_ranked_genes_filtered: the rows of df_ranked_genes meeting both criteria, in the original order
      and with the original index. Rows with a NaN logfoldchange or adjusted p-value are dropped,
      because every comparison with NaN is False.
    """
    # Adjusted p-values are non-negative, so no abs() is needed on them
    keep = (
        (df_ranked_genes['logfoldchanges'].abs() > min_fold_change)
        & (df_ranked_genes['pvals_adj'] < max_adj_pvalue)
    )
    df_ranked_genes_filtered = df_ranked_genes[keep]
    print('Number of filtered genes:', len(df_ranked_genes_filtered))

    return df_ranked_genes_filtered

In [None]:
def manual_calculation_logFC(adata_train, df_ranked_genes, group_name, groupby='sample_type'):
    """
    Recompute the logfoldchange of the genes whose logfoldchange is NaN in the ranked-gene results.

    For every gene with a NaN 'logfoldchanges' value the logFC is recalculated as
    log2(mean of adata_train.X in `group_name` / mean of adata_train.X in all other cells).
    If either mean is not positive, the logFC stays NaN.

    NOTE(review): this is a ratio of the means of adata_train.X as stored (in this notebook the
    log-transformed data plus a 0.1 pseudocount), which may not be the same quantity scanpy reports
    as 'logfoldchanges' — confirm the two are comparable before mixing them.

    Input:
    - adata_train: AnnData object whose .X holds the expression values used for the means
    - df_ranked_genes: dataframe from sc.get.rank_genes_groups_df() for one group
      (columns 'names' and 'logfoldchanges')
    - group_name: the group (value of adata_train.obs[groupby]) the dataframe belongs to
    - groupby: obs column that defines the groups (default 'sample_type')
    Output:
    - a copy of df_ranked_genes in which the NaN logfoldchanges are replaced by the manually computed
      values; the input dataframe is not modified.
    """
    group_cells = adata_train.obs[groupby] == group_name
    nan_logfc = df_ranked_genes['logfoldchanges'].isna()

    corrected_logfc = {}
    for gene in df_ranked_genes.loc[nan_logfc, 'names']:
        mean_group = adata_train[group_cells, gene].X.mean()
        mean_others = adata_train[~group_cells, gene].X.mean()

        # log2 of the ratio is only defined when both means are positive (also avoids division by zero)
        if mean_group > 0 and mean_others > 0:
            corrected_logfc[gene] = np.log2(mean_group / mean_others)
        else:
            corrected_logfc[gene] = np.nan

    # Work on a copy so the caller's dataframe is left untouched
    df_ranked_genes_manual_logFC = df_ranked_genes.copy()
    for gene, logfc in corrected_logfc.items():
        df_ranked_genes_manual_logFC.loc[df_ranked_genes_manual_logFC['names'] == gene, 'logfoldchanges'] = logfc

    return df_ranked_genes_manual_logFC



In [None]:
def threshold_genes_for_signatures(df_ranked_genes, threshold):
    """
    Select the desired number of DEGs for the gene signature: the first `threshold` rows of df_ranked_genes.

    Input:
    - df_ranked_genes: dataframe of (filtered) ranked DEGs, in ranking order
    - threshold: number of genes to keep (fewer are returned if the dataframe is shorter; 0 or less gives none)
    Output:
    - df_ranked_genes_cutoff: the first `threshold` rows with a fresh 0-based index; the previous index is kept
      as an extra column by reset_index(). The input dataframe is not modified.
    """
    # reset_index() without inplace returns a new frame, so the caller's dataframe is not changed
    # and calling this repeatedly on the same input always gives the same result.
    df_ranked_genes_cutoff = df_ranked_genes.reset_index().iloc[:max(threshold, 0)]
    print(len(df_ranked_genes_cutoff))

    return df_ranked_genes_cutoff

In [None]:
def identify_gene_signatures(X_train, grouping_label, nr_genes, method, labels, min_fold_change, max_adj_pvalue, threshold):
    """
    Identify the differentially expressed genes (DEGs) per cell type/group and turn them into gene signatures.

    Steps:
    1. Rank the genes per group with sc.tl.rank_genes_groups (statistical test given by `method`).
    2. For every category of X_train.obs[grouping_label] (in category order), convert the ranking into a
       dataframe with sc.get.rank_genes_groups_df and recompute NaN logfoldchanges with manual_calculation_logFC.
    3. Filter the DEGs on a minimum absolute logfoldchange and a maximum adjusted p-value (manually_filter_genes).
    4. Keep the first `threshold` remaining genes per group (threshold_genes_for_signatures) and plot their logFCs.

    Input:
    - X_train: AnnData object of the training data; the ranking results are added to it under the key `method`
    - grouping_label: categorical obs column of X_train by which the cells are grouped.
      NOTE: manual_calculation_logFC is called without a grouping column, so it groups the cells on
      obs['sample_type']; grouping_label should therefore be 'sample_type'.
    - nr_genes: number of genes ranked per group by sc.tl.rank_genes_groups
    - method: statistical method passed to sc.tl.rank_genes_groups (e.g. 'wilcoxon'); also used as results key
    - labels: categorical series with the label of each cell (used for the panel titles of the plot)
    - min_fold_change: filtering criterion, minimum absolute logfoldchange of DEGs
    - max_adj_pvalue: filtering criterion, maximum adjusted p-value of DEGs (i.e. maximum false discovery rate)
    - threshold: maximum number of genes kept per group for the gene signature
    Output:
    - list_dfs_ranked_genes_filtered_cutoff: list with one dataframe of selected DEGs per group, in category order
    """
    ## Identify differentially expressed genes (DEGs) between groups for training data
    sc.tl.rank_genes_groups(
        X_train,
        groupby=grouping_label,
        n_genes=nr_genes,           # number of genes to rank per group
        method=method,              # statistical method
        key_added=method,           # key under which the results are stored (read back below with key=method)
    )

    groups = X_train.obs[grouping_label].dtype.categories

    list_dfs_ranked_genes_filtered_cutoff = []
    for group in groups:
        # Ranked genes of this group as a dataframe
        df_ranked_genes = sc.get.rank_genes_groups_df(X_train, group=group, key=method)
        # Recompute NaN logfoldchanges for this same group (the group name comes from the data,
        # so it always matches the dataframe it is applied to)
        df_manual_logFC = manual_calculation_logFC(X_train, df_ranked_genes, group)
        # Filter on the logFC / adjusted p-value criteria, then keep the top `threshold` genes
        df_filtered = manually_filter_genes(df_manual_logFC, min_fold_change, max_adj_pvalue)
        list_dfs_ranked_genes_filtered_cutoff.append(threshold_genes_for_signatures(df_filtered, threshold))

    # One panel per group, 10 inches wide each (30 x 5 for three groups)
    n_groups = len(groups)
    plot_logFC_ranked_genes(list_dfs_ranked_genes_filtered_cutoff, labels, 1, n_groups, (10 * n_groups, 5), 'Logfoldchanges of ranked DEGs after filtering and thresholding')

    return list_dfs_ranked_genes_filtered_cutoff

In [None]:
# Calculate gene signature scores
def get_gene_signature_scores(adata, list_dfs_DEGs):
    """
    Calculate the gene signature scores for each group (Cycling, Moderate cyclers, Non-cycling) and
    return them together in one dataframe.

    Side effect: sc.tl.score_genes stores each score in adata.obs under its score name
    ('Gene_sig_cycling', 'Gene_sig_moderate_cyclers', 'Gene_sig_non-cycling'), so adata is modified.

    Inputs:
    - adata: AnnData object with normalized and log-transformed gene expression data
    - list_dfs_DEGs: list of three dataframes with the DEGs (gene names in the column 'names'),
      in the order cycling, moderate cyclers, non-cycling
    Output:
    - dataframe of the calculated gene signature scores per cell (rows are cells and columns are
      the gene signature scores per group)
    """
    score_names = ['Gene_sig_cycling', 'Gene_sig_moderate_cyclers', 'Gene_sig_non-cycling']

    # Index list_dfs_DEGs explicitly, so a list with fewer than three dataframes raises IndexError
    # instead of silently returning stale score columns
    for i, score_name in enumerate(score_names):
        sc.tl.score_genes(adata, list_dfs_DEGs[i]['names'], score_name=score_name)

    # Selecting a list of columns returns a new dataframe
    gene_signatures_data = adata.obs[score_names]

    return gene_signatures_data

# 1. Read preprocessed data (AnnData object)

In [None]:
# Load the preprocessed data (AnnData object) from disk
preprocessed_h5ad_path = '/home/jolien/Notebooks/data/preprocessed_data_v2.h5ad'
adata_preprocessed = ad.read_h5ad(preprocessed_h5ad_path)

# 2. Prepare data

### 2.1 Add pseudocount to log-transformed data

In [None]:
# Copy the preprocessed data into a new AnnData object, so adata_preprocessed itself is not changed
adata_psuedocount = adata_preprocessed.copy()

# Keep the preprocessed X matrix in a layer
adata_psuedocount.layers["preprocessed_Xdata"] = adata_preprocessed.X.copy()

# Ranking the differentially expressed genes needs log-transformed data in adata.X, so put the
# log-transformed data from before scaling (layer 'log_transformed') into X
adata_psuedocount.X = adata_preprocessed.layers['log_transformed']

# Check minimum and maximum value in log_transformed matrix
print('Minimum value in the preprocessed (without scaling) count matrix:', adata_psuedocount.X.min())
print('Maximum value in the preprocessed (without scaling) count matrix:', adata_psuedocount.X.max())

# Add a pseudocount of 0.1 to every entry of X, for all cells, to prevent zeros
# (assuming the log-transformed values are non-negative; see the minimum printed above).
# This replaces X itself; it is not stored as a separate layer.
adata_psuedocount.X = adata_psuedocount.X+0.1

print('Minimum value in the psuedocount matrix:', adata_psuedocount.X.min())
print('Maximum value in the psuedocount matrix:', adata_psuedocount.X.max())

### 2.2 Select persister cells (= cells at day 14)

In [None]:
# Persister cells = the cells measured at day 14
is_day14 = adata_psuedocount.obs['time_point'] == 14
adata_psuedocount_day14 = adata_psuedocount[is_day14]

# 3. Train classification model

### 3.1 Split data into training and test set

In [None]:
# Stratified 70/30 split of the cell positions, so every sample_type keeps its proportion
# in both the training and the test set
cell_positions = np.arange(adata_psuedocount_day14.n_obs)
strata = adata_psuedocount_day14.obs['sample_type']
train_indices, test_indices = train_test_split(
    cell_positions,
    test_size=0.3,
    random_state=42,   # fixed seed: the same split on every run
    stratify=strata,
)

# Subset the day-14 cells into training and test AnnData objects
adata_train, adata_test = (adata_psuedocount_day14[idx] for idx in (train_indices, test_indices))

# Class label (sample_type) of every cell in each set
y_train, y_test = adata_train.obs['sample_type'], adata_test.obs['sample_type']



In [None]:
# Report the class balance of the training set per sample_type
train_sample_types = adata_train.obs['sample_type']
for sample_type, message in [
    ('Non-cycling', 'Fraction non-cycling cells in training data {:.2f}'),
    ('Moderate_cyclers', 'Fraction moderate cycling cells in training data {:.2f}'),
    ('Cycling', 'Fraction cycling cells in training data {:.2f}'),
]:
    n_in_group = (train_sample_types == sample_type).sum()
    print(message.format(n_in_group / len(adata_train)))

### 3.2 Identify DEGs in training data set

In [None]:
# Settings for ranking the genes
nr_genes = 100          # number of genes to rank per group / in the gene signature
method = 'wilcoxon'     # statistical test used to rank the genes

# Rank the genes of each sample_type in the training data; the results are stored under the
# key `method` ('wilcoxon'), which the following cells refer to
rank_settings = {
    'groupby': 'sample_type',
    'n_genes': nr_genes,
    'method': method,
    'key_added': method,
}
sc.tl.rank_genes_groups(adata_train, **rank_settings)


In [None]:
# Top 25 ranked genes of every group (results stored under key 'wilcoxon'); panels do not share a y-axis
sc.pl.rank_genes_groups(
    adata_train,
    key='wilcoxon',
    n_genes=25,
    sharey=False,
)

In [None]:
# Dot plot of the top 10 ranked genes per sample_type
sc.pl.rank_genes_groups_dotplot(
    adata_train,
    groupby='sample_type',
    key='wilcoxon',
    n_genes=10,
)

### 3.3 Process identified DEGs per group - remove genes with NaN logfoldchange

In [None]:
# Convert the Wilcoxon results of each group into a DataFrame for easier inspection and
# drop the genes whose log fold change is NaN

def _ranked_genes_without_nan_logfc(group, show_top=False):
    """Ranked 'wilcoxon' DEGs of `group` in adata_train, without the genes whose logfoldchange is NaN."""
    ranked = sc.get.rank_genes_groups_df(adata_train, group=group, key='wilcoxon')
    if show_top:
        # preview the first 10 rows before the NaN logFCs are removed
        print(ranked.head(10))
    return ranked.dropna(subset=['logfoldchanges'])

df_ranked_genes_cycling = _ranked_genes_without_nan_logfc('Cycling', show_top=True)
print('Number of DEGs of the cycling persisters that have a logfoldchange other than NaN:',len(df_ranked_genes_cycling))

df_ranked_genes_moderate_cyclers = _ranked_genes_without_nan_logfc('Moderate_cyclers')
print('Number of DEGs of the moderate cycling persisters that have a logfoldchange other than NaN:',len(df_ranked_genes_moderate_cyclers))

df_ranked_genes_non_cycling = _ranked_genes_without_nan_logfc('Non-cycling')
print('Number of DEGs of the non-cycling persisters that have a logfoldchange other than NaN:',len(df_ranked_genes_non_cycling))

In [None]:
# View the ranked genes of cycling persisters (with the genes having NaN logfoldchange removed)
df_ranked_genes_cycling.head(10)

### 3.4 Plot log FC of the ranked genes

In [None]:
# One panel per group with the logFCs of its ranked DEGs, before any filtering.
# The list order must match the category order of y_train (panel titles come from it).
list_dfs_ranked_genes = [
    df_ranked_genes_cycling,
    df_ranked_genes_moderate_cyclers,
    df_ranked_genes_non_cycling,
]
plot_logFC_ranked_genes(
    list_dfs_ranked_genes, y_train,
    fig_rows=1, fig_columns=3, figsize=(30, 5),
    plot_title='Logfoldchanges of ranked DEGs before filtering',
)


### 3.5 Filter DEGs

In [None]:
# # filter the DEGs
# sc.tl.filter_rank_genes_groups(
#     adata_train, 
#     key='wilcoxon',
#     min_fold_change=0,
#     min_in_group_fraction=0,
#     use_raw=False
# )

# filtered_ranked_genes = adata_train.uns['rank_genes_groups_filtered']

# df_filtered_genes = sc.get.rank_genes_groups_df(adata_train, group='Non-cycling', key='rank_genes_groups_filtered') # get filtering results in a dataframe
# df_filtered_genes.dropna(subset = ['names'], inplace=True) # remove filtered out genes (genes with NaN as name)
# df_filtered_genes


`sc.tl.filter_rank_genes_groups` filtered out all genes, even with the filtering criteria disabled, which suggests it is not working as expected here. The genes are therefore filtered manually instead.

In [None]:
# Keep only DEGs with |logFC| > 0.2 and an adjusted p-value < 0.001, the criteria used by the authors
# of the persister cell paper. Genes with a NaN logfoldchange were already dropped above.
min_abs_logfc, max_padj = 0.2, 0.001
df_ranked_genes_cycling_filtered = manually_filter_genes(df_ranked_genes_cycling, min_abs_logfc, max_padj)
df_ranked_genes_moderate_cyclers_filtered = manually_filter_genes(df_ranked_genes_moderate_cyclers, min_abs_logfc, max_padj)
df_ranked_genes_non_cycling_filtered = manually_filter_genes(df_ranked_genes_non_cycling, min_abs_logfc, max_padj)




In [None]:
# Same logFC plot, now for the DEGs that passed the filtering criteria
list_dfs_ranked_genes_filtered = [
    df_ranked_genes_cycling_filtered,
    df_ranked_genes_moderate_cyclers_filtered,
    df_ranked_genes_non_cycling_filtered,
]
plot_logFC_ranked_genes(
    list_dfs_ranked_genes_filtered, y_train,
    fig_rows=1, fig_columns=3, figsize=(30, 5),
    plot_title='Logfoldchanges of ranked DEGs after filtering',
)

### 3.6 Repeat the previous steps with manually calculated logfoldchanges

In [None]:
# Re-extract the ranked genes of each group, this time KEEPING the genes with a NaN
# logfoldchange so that their logFC can be recalculated manually below
ranked_by_group = {
    group: sc.get.rank_genes_groups_df(adata_train, group=group, key='wilcoxon')
    for group in ('Cycling', 'Moderate_cyclers', 'Non-cycling')
}
df_ranked_genes_cycling = ranked_by_group['Cycling']
df_ranked_genes_moderate_cyclers = ranked_by_group['Moderate_cyclers']
df_ranked_genes_non_cycling = ranked_by_group['Non-cycling']

print('Number of DEGs of the cycling persisters:',len(df_ranked_genes_cycling))
print('Number of DEGs of the moderate cycling persisters:',len(df_ranked_genes_moderate_cyclers))
print('Number of DEGs of the non-cycling persisters:',len(df_ranked_genes_non_cycling))

In [None]:
## Sanity check on one gene with a NaN logfoldchange: SAT1 is differentially expressed in the
## non-cycling cells, but its logfoldchange is NaN

is_noncycling = adata_train.obs['sample_type']=='Non-cycling'
sat1_noncycling = adata_train[is_noncycling, 'SAT1'].X
sat1_others = adata_train[~is_noncycling, 'SAT1'].X

# SAT1 expression in the non-cycling group
mean_SAT1_noncycling = np.mean(sat1_noncycling)
std_SAT1_noncycling = np.std(sat1_noncycling)
print(f"The mean SAT1 expression in non_cycling cells {mean_SAT1_noncycling:.2f}")
print(f"The standard deviation of SAT1 expression in non_cycling cells {std_SAT1_noncycling:.2f}")
print('\n')

# SAT1 expression in all other cells
mean_SAT1_not_noncycling = np.mean(sat1_others)
std_SAT1_not_noncycling = np.std(sat1_others)
print(f"The mean SAT1 expression in the other cells {mean_SAT1_not_noncycling:.2f}")
print(f"The standard deviation of SAT1 expression in the other cells {std_SAT1_not_noncycling:.2f}")
print('\n')

# log2 of the ratio of the two means = manually calculated log fold change
print(f"The manually calculated log fold change is {np.log2(mean_SAT1_noncycling/mean_SAT1_not_noncycling):.2f}")

In [None]:
def manual_calculation_logFC(adata_train, df_ranked_genes, group_name, groupby='sample_type'):
    """
    Recompute the logfoldchange of the genes whose logfoldchange is NaN in the ranked-gene results.

    NOTE: this cell redefines manual_calculation_logFC from the Functions section of this notebook;
    keep both definitions identical.

    For every gene with a NaN 'logfoldchanges' value the logFC is recalculated as
    log2(mean of adata_train.X in `group_name` / mean of adata_train.X in all other cells).
    If either mean is not positive, the logFC stays NaN.

    NOTE(review): this is a ratio of the means of adata_train.X as stored (in this notebook the
    log-transformed data plus a 0.1 pseudocount), which may not be the same quantity scanpy reports
    as 'logfoldchanges' — confirm the two are comparable before mixing them.

    Input:
    - adata_train: AnnData object whose .X holds the expression values used for the means
    - df_ranked_genes: dataframe from sc.get.rank_genes_groups_df() for one group
      (columns 'names' and 'logfoldchanges')
    - group_name: the group (value of adata_train.obs[groupby]) the dataframe belongs to
    - groupby: obs column that defines the groups (default 'sample_type')
    Output:
    - a copy of df_ranked_genes in which the NaN logfoldchanges are replaced by the manually computed
      values; the input dataframe is not modified.
    """
    group_cells = adata_train.obs[groupby] == group_name
    nan_logfc = df_ranked_genes['logfoldchanges'].isna()

    corrected_logfc = {}
    for gene in df_ranked_genes.loc[nan_logfc, 'names']:
        mean_group = adata_train[group_cells, gene].X.mean()
        mean_others = adata_train[~group_cells, gene].X.mean()

        # log2 of the ratio is only defined when both means are positive (also avoids division by zero)
        if mean_group > 0 and mean_others > 0:
            corrected_logfc[gene] = np.log2(mean_group / mean_others)
        else:
            corrected_logfc[gene] = np.nan

    # Work on a copy so the caller's dataframe is left untouched
    df_ranked_genes_manual_logFC = df_ranked_genes.copy()
    for gene, logfc in corrected_logfc.items():
        df_ranked_genes_manual_logFC.loc[df_ranked_genes_manual_logFC['names'] == gene, 'logfoldchanges'] = logfc

    return df_ranked_genes_manual_logFC



In [None]:
# Fill in the NaN logfoldchanges of each group with manually calculated values
(
    df_ranked_genes_cycling_manual_logFC,
    df_ranked_genes_moderate_cyclers_manual_logFC,
    df_ranked_genes_non_cycling_manual_logFC,
) = (
    manual_calculation_logFC(adata_train, df, group)
    for df, group in [
        (df_ranked_genes_cycling, 'Cycling'),
        (df_ranked_genes_moderate_cyclers, 'Moderate_cyclers'),
        (df_ranked_genes_non_cycling, 'Non-cycling'),
    ]
)

In [None]:
# Inspect the first 10 ranked cycling genes, now with the manually calculated logFCs filled in
df_ranked_genes_cycling_manual_logFC.iloc[:10]

In [None]:
# Drop the genes whose logfoldchange is still NaN after the manual recalculation
# (manual_calculation_logFC leaves NaN when one of the group means is not positive)
# and report how many ranked DEGs remain per group

df_ranked_genes_cycling_manual_logFC.dropna(subset = ['logfoldchanges'], inplace=True)
print('Number of DEGs of the cycling persisters:',len(df_ranked_genes_cycling_manual_logFC))

df_ranked_genes_moderate_cyclers_manual_logFC.dropna(subset = ['logfoldchanges'], inplace=True)
print('Number of DEGs of the moderate cycling persisters:',len(df_ranked_genes_moderate_cyclers_manual_logFC))

df_ranked_genes_non_cycling_manual_logFC.dropna(subset = ['logfoldchanges'], inplace=True)
print('Number of DEGs of the non-cycling persisters:',len(df_ranked_genes_non_cycling_manual_logFC))

In [None]:
# LogFC plot per group using the manually completed logfoldchanges, before filtering
list_dfs_ranked_genes_manual_logFC = [
    df_ranked_genes_cycling_manual_logFC,
    df_ranked_genes_moderate_cyclers_manual_logFC,
    df_ranked_genes_non_cycling_manual_logFC,
]
plot_logFC_ranked_genes(
    list_dfs_ranked_genes_manual_logFC, y_train,
    fig_rows=1, fig_columns=3, figsize=(30, 5),
    plot_title='Manually calculated logfoldchanges of ranked DEGs before filtering',
)

In [None]:
# Filter the DEGs (with manually completed logFCs) using the criteria of the persister cell paper:
# |logFC| > 0.2 and adjusted p-value < 0.001. Genes whose logFC is still NaN were dropped in the previous cell.
(
    df_ranked_genes_cycling_filtered,
    df_ranked_genes_moderate_cyclers_filtered,
    df_ranked_genes_non_cycling_filtered,
) = (
    manually_filter_genes(df, 0.2, 0.001)
    for df in (
        df_ranked_genes_cycling_manual_logFC,
        df_ranked_genes_moderate_cyclers_manual_logFC,
        df_ranked_genes_non_cycling_manual_logFC,
    )
)




In [None]:
# LogFC plot per group for the DEGs that passed the filtering criteria
list_dfs_ranked_genes_filtered = [
    df_ranked_genes_cycling_filtered,
    df_ranked_genes_moderate_cyclers_filtered,
    df_ranked_genes_non_cycling_filtered,
]
plot_logFC_ranked_genes(
    list_dfs_ranked_genes_filtered, y_train,
    fig_rows=1, fig_columns=3, figsize=(30, 5),
    plot_title='Logfoldchanges of ranked DEGs after filtering',
)

In [None]:
# Select the top x genes per gene set (in ranking order) to ensure an equal number of genes per gene set
nr_genes_threshold = 75

# reset_index() without inplace returns a new frame with a 0-based index (the former index becomes
# an 'index' column), so the filtered dataframes are left unchanged and re-running this cell gives
# the same result.
df_ranked_genes_cycling_filtered_cutoff = df_ranked_genes_cycling_filtered.reset_index().iloc[:nr_genes_threshold]
df_ranked_genes_moderate_cyclers_filtered_cutoff = df_ranked_genes_moderate_cyclers_filtered.reset_index().iloc[:nr_genes_threshold]
df_ranked_genes_non_cycling_filtered_cutoff = df_ranked_genes_non_cycling_filtered.reset_index().iloc[:nr_genes_threshold]

print(len(df_ranked_genes_cycling_filtered_cutoff), len(df_ranked_genes_moderate_cyclers_filtered_cutoff), len(df_ranked_genes_non_cycling_filtered_cutoff))

In [None]:
# Plot the logfoldchanges of the filtered and thresholded (top nr_genes_threshold) DEGs per group.
# The title names the thresholding step so this figure is not confused with the previous one.
list_dfs_ranked_genes_filtered_cutoff = [df_ranked_genes_cycling_filtered_cutoff, df_ranked_genes_moderate_cyclers_filtered_cutoff, df_ranked_genes_non_cycling_filtered_cutoff]
plot_logFC_ranked_genes(list_dfs_ranked_genes_filtered_cutoff, y_train, 1, 3, (30,5), 'Logfoldchanges of ranked DEGs after filtering and thresholding')

In [None]:
# Bar plot of the logfoldchanges of the non-cycling gene signature, sorted on logfoldchange (ascending)
top_genes_sorted = df_ranked_genes_non_cycling_filtered_cutoff.sort_values('logfoldchanges')

plt.figure(figsize=(12, 6))

# Create the bar plot
plt.bar(
    top_genes_sorted['names'],
    top_genes_sorted['logfoldchanges']
)

# Add axis labels
plt.xlabel("Ranked gene name", fontsize=14)
plt.ylabel("Log Fold Change", fontsize=14)
plt.title("Non-cycling", fontsize=20)
# Rotate only the gene-name (x-axis) tick labels; without axis='x' the y-axis values are rotated as well
plt.tick_params(axis='x', rotation=90)
# plt.savefig('/home/jolien/Notebooks/gene_signature/figures/Final_top75_logfoldchange_non_cycling_manually_filtered.png')


In [None]:
# Table with the names of the top-ranked genes of each gene signature (one column per group),
# listed in the order of the filtered and thresholded ranked-gene dataframes.
# .iloc[:n_top_genes] takes exactly n_top_genes rows; label slicing with .loc[:10] includes
# label 10 and would return 11 genes.
n_top_genes = 10
top_genes = pd.DataFrame({
    'Non-cycling': df_ranked_genes_non_cycling_filtered_cutoff['names'].iloc[:n_top_genes].reset_index(drop=True),
    'Moderate cyclers': df_ranked_genes_moderate_cyclers_filtered_cutoff['names'].iloc[:n_top_genes].reset_index(drop=True),
    'Cycling': df_ranked_genes_cycling_filtered_cutoff['names'].iloc[:n_top_genes].reset_index(drop=True),
})

top_genes

### 3.7 Train classification model

In [None]:
# Score the training and the test cells on the same gene signatures (derived from the training data);
# each result has one column per signature
gene_signatures_train, gene_signatures_test = (
    get_gene_signature_scores(adata_part, list_dfs_ranked_genes_filtered_cutoff)
    for adata_part in (adata_train, adata_test)
)

In [None]:
# Random forest on the gene signature scores; the fixed seed makes the fit reproducible
rf_classifier = RandomForestClassifier(random_state=42)
rf_classifier.fit(gene_signatures_train, y_train)

# Predicted class of every test cell
y_pred = rf_classifier.predict(gene_signatures_test)

In [None]:
# Evaluate the predictions on the held-out test cells
test_accuracy = accuracy_score(y_test, y_pred)
print("Accuracy Score:", test_accuracy)

print("Classification Report:")
report = classification_report(y_test, y_pred)
print(report)

In [None]:
print("Confusion Matrix:")

# Class order used for both the matrix and the tick labels. Passing it to confusion_matrix
# guarantees that the rows/columns line up with these labels (without `labels`, scikit-learn
# uses the sorted classes that occur in y_test or y_pred).
class_names = ['Cycling','Moderate_cyclers','Non-cycling']

# normalize='true' divides every row by the number of cells of that actual class, giving ratios
# instead of absolute counts; the diagonal is then the recall per class
conf_matrix = confusion_matrix(y_test, y_pred, labels=class_names, normalize='true')

# Plot confusion matrix (ratios shown with two decimals)
sns.heatmap(conf_matrix, annot=True, cmap='Blues', fmt='.2f')

# Add labels to the plot; +0.5 puts each tick at the centre of a heatmap cell
tick_marks = np.arange(len(class_names)) + 0.5
plt.xticks(tick_marks, class_names, rotation=25, fontsize=12)
plt.yticks(tick_marks, class_names, rotation=0, fontsize=12)
plt.xlabel('Predicted', fontsize=14)
plt.ylabel('Actual', fontsize=14)

# Save and show figure
# plt.savefig('figures/confusion_matrix.png')
plt.show()

In [None]:
# Predicted probability of every class for each test cell
probs = rf_classifier.predict_proba(X=gene_signatures_test)

In [None]:
# Convert probabilities into a dataframe and add the predicted class.
# predict_proba orders its columns like rf_classifier.classes_, so the column names are taken from
# classes_ (the category order of y_test is not guaranteed to be the same order).
df_classification_probs = pd.DataFrame(probs, columns=['prob_' + str(label) for label in rf_classifier.classes_])
df_classification_probs['Prediction'] = y_pred
df_classification_probs

In [None]:
# For each persister category: among the cells predicted as that category, the mean predicted
# probability of that category
for group in y_train.dtype.categories:
    predicted_as_group = df_classification_probs['Prediction'] == group
    mean_prob = df_classification_probs.loc[predicted_as_group, 'prob_' + group].mean()
    print(group,'- mean probability of the prediction for the class {:.2f}'.format(mean_prob))

In [None]:
# Predict the cell fate of all cells (every time point, including the day-14 cells)

# Score all cells on exactly the gene sets the classifier was trained on
# (list_dfs_ranked_genes_filtered_cutoff). Scores computed from other gene sets, such as the
# untruncated filtered lists, would be different features than the ones rf_classifier was fitted on.
gene_signatures_data_complete = get_gene_signature_scores(adata_psuedocount, list_dfs_ranked_genes_filtered_cutoff)

y_pred_complete = rf_classifier.predict(gene_signatures_data_complete)

# Store the predicted cell fate of every cell in a new obs column of the AnnData object
adata_psuedocount.obs['cell_fate_prediction']=y_pred_complete
adata_psuedocount.obs

In [None]:
### Predicted label/cell type
# Count the cells per (time point, predicted cell fate)
grouped_data = adata_psuedocount.obs.groupby(['time_point','cell_fate_prediction']).size().reset_index(name='count')
# One row per time point, one column per predicted fate; combinations without cells become 0
pivot = grouped_data.pivot(index='time_point', columns='cell_fate_prediction', values='count').fillna(0)
# Convert counts to percentages per time point (each row sums to 100)
pivot_percentage = pivot.div(pivot.sum(axis=1), axis=0) * 100
print(pivot_percentage)


### Real label/cell type of the day-14 cells
grouped_data_day14 = adata_psuedocount_day14.obs.groupby(['time_point','sample_type']).size().reset_index(name='count')
pivot_day14 = grouped_data_day14.pivot(index='time_point', columns='sample_type', values='count').fillna(0)
pivot_day14_percentage = pivot_day14.div(pivot_day14.sum(axis=1), axis=0) * 100
print(pivot_day14_percentage)

### Plot the stacked barplot
import matplotlib.cm as cm
from matplotlib.gridspec import GridSpec

# Colormap for the predicted fates. matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and
# removed in 3.9, so look the colormap up through pyplot instead.
cmap = plt.get_cmap("PuBu")
# Three equally spaced colours, avoiding the lightest shades of the colormap
colors = cmap([0.7, 0.5, 0.3])
# Colours for the real labels
# NOTE(review): the two panels use different colours for the same cell fate, which makes them
# harder to compare — confirm this is intended
colors_real_labels = ['#1f77b4','#ff7f0e','#2ca02c']

# Figure with two panels side by side. Width ratios 4 : 1 : 0.1 — the third, very narrow
# column stays empty and only adds some space on the right.
fig = plt.figure(figsize=(12, 6))
gs = GridSpec(1, 3, width_ratios=[4, 1, 0.1])

ax1 = fig.add_subplot(gs[0])  # predicted cell fates (wide panel)
ax2 = fig.add_subplot(gs[1])  # actual cell fates (narrow panel)

# White background and a black border around each panel
for ax in (ax1, ax2):
    ax.set_facecolor('white')
    for spine in ax.spines.values():
        spine.set_edgecolor('black')

pivot_percentage.plot(kind="bar", stacked=True, ax=ax1, color=colors, fontsize=12)
ax1.set_xlabel('Day of measurement', fontsize=14)
ax1.set_ylabel('Percentage of cells (%)', fontsize=14)
ax1.set_title('Predicted cell fates', fontsize=16)
ax1.tick_params(rotation=0)

pivot_day14_percentage.plot(kind="bar", stacked=True, ax=ax2, color=colors_real_labels, fontsize=12)
ax2.set_xlabel('Day of measurement', fontsize=14)
ax2.set_ylabel('Percentage of cells (%)', fontsize=14)
ax2.set_title('Actual cell fate', fontsize=16)
ax2.tick_params(rotation=0)

plt.tight_layout()
# plt.savefig('/home/jolien/Notebooks/gene_signature/figures/Distribution_predicted_vs_real_classes.png')
plt.show()

In [None]:
# # UMAP plot colored by predicted cell fate (predicted based on gene signature) and sample type
# sc.tl.umap(adata_psuedocount,random_state=123)
# sc.pl.umap(adata_psuedocount, color=['sample_type','cell_fate_prediction'], save="UMAP_predicted_cell_fate.png")

In [None]:
# Force-directed graph (PAGA graph used as the initial cluster positions),
# colored by the predicted cell fate
color_keys = ['cell_fate_prediction']
sc.pl.draw_graph(adata_psuedocount, color=color_keys)  # , save="_PAGA_predicted_cell_fate.png")

# Cross validation

In [None]:
# Define parameters
n_splits = 5            # Number of folds for cross-validation
nr_genes = 150          # Number of genes to rank / in gene signature
method = 'wilcoxon'     # Statistical method to identify gene signatures
nr_DEGs_threshold = 75  # Forwarded to identify_gene_signatures (presumably a cap on DEGs kept per group - confirm)

# Stratified K-Folds cross-validator --> keeps class distribution similar in each fold
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Day-14 labels: classification target and stratification variable
labels_day14 = adata_psuedocount_day14.obs['sample_type']

# Per-fold results, one entry per fold (in fold order)
all_DEGs = []                           # filtered DEG tables from each fold
all_models = []                         # fitted classification model from each fold
all_predictions = []                    # test-set class probabilities + predicted class from each fold
all_fold_accuracies = []                # test-set accuracy from each fold
all_fold_classification_reports = []    # test-set classification report (DataFrame) from each fold
y_tests = []                            # true labels of the test set from each fold


### Model development
# Loop over each fold in StratifiedKFold
for train_index, test_index in skf.split(adata_psuedocount_day14, labels_day14):

    ## Step 1: Split data in train and test set
    X_train, X_test = adata_psuedocount_day14[train_index], adata_psuedocount_day14[test_index]
    # The fold indices are positions. .iloc makes that explicit: plain [] with integers on a
    # string-indexed Series relies on pandas' deprecated positional fallback.
    y_train, y_test = labels_day14.iloc[train_index], labels_day14.iloc[test_index]
    y_tests.append(y_test)  # kept to build confusion matrices later

    ## Step 2: Identify differentially expressed genes (DEGs) between groups on the training data
    ## and filter them with the criteria of the authors of the persister cell paper
    list_dfs_ranked_genes_filtered = identify_gene_signatures(X_train, 'sample_type', nr_genes, method, y_train, 0.2, 0.001, nr_DEGs_threshold)
    all_DEGs.append(list_dfs_ranked_genes_filtered)

    ## Step 3: Calculate gene signature scores based on the training-fold DEGs
    gene_signatures_train = get_gene_signature_scores(X_train, list_dfs_ranked_genes_filtered)
    gene_signatures_test = get_gene_signature_scores(X_test, list_dfs_ranked_genes_filtered)

    ## Step 4: Train classification model (Random Forest model)
    rf_classifier = RandomForestClassifier(random_state=42)
    rf_classifier.fit(gene_signatures_train, y_train)
    all_models.append(rf_classifier)

    ## Step 5: Predict on the test set
    y_pred = rf_classifier.predict(gene_signatures_test)
    probs = rf_classifier.predict_proba(gene_signatures_test)
    # predict_proba columns are ordered like rf_classifier.classes_ (the sorted labels seen during
    # fit), which need not match the categorical order - or category count - of y_test.
    prob_columns = ['prob_' + str(label) for label in rf_classifier.classes_]
    df_classification_probs = pd.DataFrame(probs, columns=prob_columns)
    df_classification_probs['Prediction'] = y_pred
    all_predictions.append(df_classification_probs)

    ## Step 6: Evaluate the performance and store the results
    accuracy = accuracy_score(y_test, y_pred)
    all_fold_accuracies.append(accuracy)
    # Precision, recall and f1 per class, as a DataFrame
    report = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True))
    all_fold_classification_reports.append(report)


### Model predictions on cells of other days
# Select the model with the highest test accuracy (first one on ties).
# NOTE(review): choosing a fold by its own test-set accuracy makes that accuracy an optimistic
# estimate; report the mean over folds as the performance figure.
best_model_index = all_fold_accuracies.index(max(all_fold_accuracies))
best_model = all_models[best_model_index]
print(f'Accuracy of the best model: {all_fold_accuracies[best_model_index]}')
print(f'Classification report of the best model: {all_fold_classification_reports[best_model_index]}')

# Gene signature score per group for the complete data set, using the best model's gene signature
gene_signatures_data_complete = get_gene_signature_scores(adata_psuedocount, all_DEGs[best_model_index])

# Predict on the cells from other time points
y_pred_complete = best_model.predict(gene_signatures_data_complete)

# Assign predicted cell fates to new obs in the adata object
adata_psuedocount.obs['cell_fate_prediction'] = y_pred_complete

In [None]:
# Accuracy of every fold, then display the classification report of the first fold
print(all_fold_accuracies)
first_fold_report = all_fold_classification_reports[0]
first_fold_report

In [None]:
# Row-normalized confusion matrix for every cross-validation fold.
# Iterate over the stored results instead of a hard-coded range(5), so this follows n_splits.
for model_index, (fold_model, y_true_fold, fold_predictions) in enumerate(zip(all_models, y_tests, all_predictions)):
    print(model_index)

    # Fix the label order explicitly so matrix rows/columns and tick labels always agree
    class_names = list(fold_model.classes_)
    conf_matrix = confusion_matrix(y_true_fold, fold_predictions['Prediction'], labels=class_names)
    # Normalize each row (actual class) to get ratios instead of absolute counts
    conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1, keepdims=True)

    # Plot confusion matrix with the matching class labels
    sns.heatmap(conf_matrix, annot=True, cmap='Blues', fmt='.2f',
                xticklabels=class_names, yticklabels=class_names)
    plt.xticks(rotation=25)
    plt.yticks(rotation=0)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Fold {model_index}')

    # Save and show figure
    # plt.savefig(f'figures/confusion_matrix_fold_{model_index}.png')
    plt.show()

In [None]:
from pathlib import Path

print("Confusion Matrix of the best model from the crossvalidation:")

# Fix the label order explicitly so matrix rows/columns and tick labels always agree
class_names = list(best_model.classes_)
conf_matrix_best_model = confusion_matrix(
    y_tests[best_model_index],
    all_predictions[best_model_index]['Prediction'],
    labels=class_names,
)
# Normalize each row (actual class) to get ratios instead of absolute counts
conf_matrix_best_model = conf_matrix_best_model.astype('float') / conf_matrix_best_model.sum(axis=1, keepdims=True)

# Plot confusion matrix with the matching class labels
sns.heatmap(conf_matrix_best_model, annot=True, cmap='Blues', fmt='.2f',
            xticklabels=class_names, yticklabels=class_names)
plt.xticks(rotation=25)
plt.yticks(rotation=0)
plt.xlabel('Predicted')
plt.ylabel('Actual')

# Save and show figure; create the output directory first so savefig cannot fail on a missing folder
figures_dir = Path('figures')
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'confusion_matrix_best_model.png')
plt.show()

In [None]:
# all_predictions[0][all_predictions[0]['Prediction'] != all_predictions[1]['Prediction']]

In [None]:
# all_predictions[1][all_predictions[0]['Prediction'] != all_predictions[1]['Prediction']]

### Visualize predicted classes with best model

In [None]:
### Predicted label/cell type (best cross-validation model): percentage of cells per fate, per time point
# observed=False spells out the current pandas default for categorical groupers and
# silences the FutureWarning about that default changing.
grouped_data = (
    adata_psuedocount.obs
    .groupby(['time_point', 'cell_fate_prediction'], observed=False)
    .size()
    .reset_index(name='count')
)
# Rows: time points, columns: predicted fates; missing combinations become 0
pivot = grouped_data.pivot(index='time_point', columns='cell_fate_prediction', values='count').fillna(0)
# Convert counts to row percentages (every time point sums to 100 %)
pivot_percentage = pivot.div(pivot.sum(axis=1), axis=0) * 100
print(pivot_percentage)


### Real label/cell type at day 14
grouped_data_day14 = (
    adata_psuedocount_day14.obs
    .groupby(['time_point', 'sample_type'], observed=False)
    .size()
    .reset_index(name='count')
)
pivot_day14 = grouped_data_day14.pivot(index='time_point', columns='sample_type', values='count').fillna(0)
pivot_day14_percentage = pivot_day14.div(pivot_day14.sum(axis=1), axis=0) * 100
print(pivot_day14_percentage)

### Plot the stacked barplots
from pathlib import Path
from matplotlib.gridspec import GridSpec

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap works on both older and newer versions.
cmap = plt.get_cmap("PuBu")
# Three shades of the colormap, avoiding the lightest part
colors = cmap([0.3, 0.5, 0.7])

# GridSpec: wide panel for the predictions, narrow panel for day 14, thin spacer column
fig = plt.figure(figsize=(12, 6))
gs = GridSpec(1, 3, width_ratios=[4, 1, 0.1])

ax1 = fig.add_subplot(gs[0])  # predicted fates per time point (wider)
ax2 = fig.add_subplot(gs[1])  # actual fates at day 14 (narrower)

pivot_percentage.plot(kind="bar", stacked=True, ax=ax1, color=colors, xlabel='Day of measurement', ylabel='Percentage of cells (%)', title='Distribution predicted classes per time point')
pivot_day14_percentage.plot(kind="bar", stacked=True, ax=ax2, cmap='viridis', xlabel='Day of measurement', ylabel='Percentage of cells (%)', title='Distribution real classes at day 14')

plt.tight_layout()
# Save to the relative 'figures' directory (as the confusion-matrix cell and scanpy's default
# figdir do) instead of a user-specific absolute path; create it if it does not exist yet.
figures_dir = Path('figures')
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'Distribution_predicted_vs_real_classes_best_model.png')
plt.show()

In [None]:
# Compute a UMAP embedding (fixed seed for reproducibility), then plot it colored by
# the actual sample type and by the fate predicted from the gene signature
sc.tl.umap(adata_psuedocount, random_state=123)
umap_colors = ['sample_type', 'cell_fate_prediction']
sc.pl.umap(adata_psuedocount, color=umap_colors, save="UMAP_predicted_cell_fate_best_model.png")

# Save AnnData object with predicted classes

In [None]:
# Copy the cell fate predictions to the original AnnData object for saving
# (the pseudocount data and gene signature scores do not need to be stored).
# Assigning a Series to a column aligns on the obs index (cell names).
predicted_fates = adata_psuedocount.obs['cell_fate_prediction']
adata_preprocessed.obs['Predicted_cell_fate'] = predicted_fates

In [None]:
# adata_preprocessed.write('/home/jolien/Notebooks/data/preprocessed_data_v2_with_predicted_class_v2.h5ad')