### 0. Import libraries

In [None]:
import pandas as pd
import numpy as np
import polars as pl
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import scib

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.rcParams["savefig.dpi"] = 300
plt.rcParams["figure.figsize"] = [6, 4.5]

# 1. Load data

### Load metadata

In [None]:
# Input files: all three live in the same benchmark directory
DATA_DIR = '/data/benchmarks/scRNAseq_persisters'
metadata_file_path = DATA_DIR + '/GSE150949_metaData_with_lineage.txt'
metadata_seurat_file_path = DATA_DIR + '/metadata_seuratobject.csv'
count_matrix_file_path = DATA_DIR + '/GSE150949_pc9_count_matrix.csv'

In [None]:
# Load the tab-separated metadata file
df_metaData_with_lineage = pd.read_csv(metadata_file_path, sep="\t")

# Load the metadata exported from the Seurat object; it provides the majority fate
df_metadata_seurat = pd.read_csv(metadata_seurat_file_path)

# Copy the majority fate into the main metadata frame.
# .tolist() drops the index, so the values are assigned by row position.
# NOTE(review): this assumes both files list the cells in the same order - TODO confirm
df_metaData_with_lineage['majority_fate'] = df_metadata_seurat['majority_fate'].tolist()

##### Analyzing metadata

In [None]:
# Sanity-check the metadata against three QC thresholds by counting the cells outside each one
mito_fraction = df_metaData_with_lineage['percent.mito']
genes_per_cell = df_metaData_with_lineage['nGene']

# Cells with a high mitochondrial fraction
print('The number of cells with >0.1 mitochondrial fraction is =', int((mito_fraction > 0.1).sum()))
# Cells with very few detected genes
print('The number of cells with <1000 genes is =', int((genes_per_cell < 1000).sum()))
# Cells with unusually many detected genes
print('The number of cells with >4200 genes is =', int((genes_per_cell > 4200).sum()))

Since there are no cells with a mitochondrial fraction >0.1, or with <1000 or >4200 detected genes, the data appear to have already been quality-filtered by Oren et al. (2021).

##### Preprocessing metadata

In [None]:
# Work on a copy so the metadata frame stays as loaded.
# Rename the sample_type labels (14_high -> Non-cycling, etc.) to avoid confusion.
# DataFrame.replace matches these values in every column, not only in sample_type.
copy_df = df_metaData_with_lineage.copy().replace({
    '14_high': 'Non-cycling',
    '14_med': 'Moderate_cyclers',
    '14_low': 'Cycling',
})

### Load count matrix data (scRNA-seq data) & convert into AnnData object

In [None]:
# Read the count matrix CSV with polars, chosen here because it is more efficient than pandas
df_pc9_count_matrix = pl.read_csv(count_matrix_file_path)

In [None]:
# Preview the first 10 rows of the count matrix
df_pc9_count_matrix.head(10)

In [None]:
# The first column holds the gene names; every remaining column is one cell
gene_names = df_pc9_count_matrix.to_series(0).to_list()
df_pc9_count_matrix_without_genenames = df_pc9_count_matrix.drop(df_pc9_count_matrix.columns[0])

# The remaining column headers are the cell names
cell_names = df_pc9_count_matrix_without_genenames.columns

# Convert to a numpy matrix, which is needed to build the AnnData object
numpy_count_matrix = df_pc9_count_matrix_without_genenames.to_numpy()

# Create the AnnData object; the matrix is transposed so that rows (obs) are cells
# and columns (var) are genes
adata = ad.AnnData(
    X=numpy_count_matrix.T,
    obs=pd.DataFrame(index=cell_names),
    var=pd.DataFrame(index=gene_names),
)

In [None]:
# Display the AnnData summary (number of observations and variables)
adata

So, the number of cells = 56419 and the number of genes = 22166

##### Enter relevant metadata to the AnnData object

In [None]:
# Enter relevant metadata to the AnnData object

# pandas aligns every assigned Series on its index, so the metadata must be indexed by the
# same cell names as adata.obs. A cell without a matching row would silently become NaN,
# so fail loudly instead.
cells_without_metadata = int((~adata.obs_names.isin(df_metaData_with_lineage.index)).sum())
if cells_without_metadata:
    raise ValueError(
        '{} cells in adata have no matching row in the metadata index'.format(cells_without_metadata)
    )

# Get lineage barcode in adata object
adata.obs['lineage_barcode'] = df_metaData_with_lineage['lineage_barcode'] # lineage barcodes from metadata of GEO

# Get time points as categorical in adata object
time_points_cat = df_metaData_with_lineage.time_point.astype('category') # convert dtype to category (for plotting later on)
adata.obs['time_point'] = time_points_cat # add categorical time points to adata object

# Get sample types as categorical in adata object (= time points for cells from day 0 - 7 and cell fate categories for day-14 cells)
sample_type_cat = copy_df.sample_type.astype('category') # convert dtype to category (for plotting later on)
adata.obs['sample_type'] = sample_type_cat # add categorical sample type to adata object

# Get majority fate of the lineages in adata object
majority_fate_cat = df_metaData_with_lineage.majority_fate.astype('category') # convert dtype to category (for plotting later on)
adata.obs['majority_fate'] = majority_fate_cat # add categorical majority fate to adata object

# Show the updated object as the cell output
adata

##### Filter out zero-count genes

In [None]:
# Store the original number of cells and genes so the effect of filtering can be reported
number_cells_before_filtering = adata.n_obs
number_genes_before_filtering = adata.n_vars

# Only consider cells with more than X genes --> not applied for now
# sc.pp.filter_cells(adata, min_genes=200)

# Keep only genes with at least 1 count in total, i.e. drop genes with zero counts
sc.pp.filter_genes(adata, min_counts=1)

# Print filtering results; the gene message describes the min_counts threshold applied above
print('Filtered out {} cells that have less than the minimum amount of genes expressed'.format(number_cells_before_filtering-adata.n_obs),'--> No filter on the cells to have a minimum amount of genes detected','\n',
      'Filtered out {} genes that have fewer than 1 count in total'.format(number_genes_before_filtering-adata.n_vars))

Apparently there were no zero-count genes

### 3. Normalization

In [None]:
# Normalize gene expression matrix with total UMI count per cell
adata.X = adata.X.astype('float64') # Convert the main data matrix to float64, because normalization was not possible with int64 values
sc.pp.normalize_per_cell(adata, key_n_counts='n_counts_all')

### 4. Identification of highly variable genes

Removing non-variable genes reduces the calculation time during the GRN reconstruction and simulation steps. It also improves the overall accuracy of the GRN inference by removing noisy genes. Using the top 2000~3000 variable genes is recommended.

In [None]:
# Select top 2000 highly-variable genes
filter_result = sc.pp.filter_genes_dispersion(adata.X,
                                              flavor='cell_ranger',
                                              n_top_genes=2000,
                                              log=False)

# Subset the genes
adata = adata[:, filter_result.gene_subset]

# Renormalize after filtering - making the total expression per cell equal across the dataset
sc.pp.normalize_per_cell(adata)

### 5. Log transformation

In [None]:
# keep raw count data before log transformation
adata.raw = adata
adata.layers["raw_count"] = adata.raw.X.copy()

# Log transformation 
sc.pp.log1p(adata) # The "log1p" function means taking the natural logarithm of (1 + X) for each value in the expression matrix, the addition of 1 ensures all values, including zeros, are log-transformed without creating NaN values

# Keep log_transformed data before scaling
adata.layers["log_transformed"] = adata.X.copy()

# Scaling 
sc.pp.scale(adata)

adata

In [None]:
# Inspect the per-cell annotations collected so far
adata.obs

### 6. PCA and neighbor calculations

In [None]:
# PCA
sc.tl.pca(adata, svd_solver='arpack')

# Diffusion map
sc.pp.neighbors(adata, n_neighbors=4, n_pcs=20)
sc.tl.diffmap(adata)

# Calculate neihbors again based on diffusionmap
sc.pp.neighbors(adata, n_neighbors=10, use_rep='X_diffmap')

The diffusion map is applied to denoise the neighbor graph.

### 7. Cell clustering

In [None]:
# Run Louvain clustering
# Run Louvain clustering with scanpy's default settings
sc.tl.louvain(adata)

### 8. Dimensionality reduction using PAGA and force-directed graphs as well as UMAP

In [None]:
# PAGA graph construction
# PAGA graph construction, using the 'louvain' annotation as the groups
sc.tl.paga(adata, groups='louvain')

In [None]:
# Set the figure size for this plot, then draw the PAGA graph
plt.rcParams["figure.figsize"] = [6, 4.5]
sc.pl.paga(adata)

In [None]:
# Calculate force-directed graph with PAGA graph as initial cluster position
# Calculate force-directed graph with the PAGA graph as initial cluster position
sc.tl.draw_graph(adata, init_pos='paga', random_state=123) # Fixed random seed to keep the layout consistent between runs

In [None]:
# Calculate UMAP 
# Calculate UMAP with a fixed random seed
sc.tl.umap(adata,random_state=123)

### 9. Visualization

In [None]:
# Colors shared by both palettes for the day 0, 3 and 7 samples
_day_colors = {
    0: '#F560A6',  # Pink
    3: '#91307F',  # Purple
    7: '#2D0059',  # Dark purple
}

# Palette for `sample_type`: string day labels followed by the three cell-fate categories
sample_type_palette = {str(day): color for day, color in _day_colors.items()}
sample_type_palette.update({
    'Cycling': '#1f77b4',  # Blue
    'Moderate_cyclers': '#ff7f0e',  # Orange
    'Non-cycling': '#2ca02c',  # Green
})

# Palette for `time_point`: integer day labels, with day 14 in grey
sample_type_palette_time = {**_day_colors, 14: '#5b5b5b'}

In [None]:
# Plot force-directed graph with PAGA graph as initial cluster position - legend next to plot
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
sc.pl.draw_graph(adata, color='louvain', legend_loc='on data', ax=axes[0], show=False)                                         # First plot with legend on data
sc.pl.draw_graph(adata, color='time_point', ax=axes[0], palette=sample_type_palette_time, show=False)                          # Second plot 
sc.pl.draw_graph(adata, color='sample_type', ax=axes[1], palette=sample_type_palette, show=False)                              # Third plot                     # Third plot 

# Save the combined plot
plt.tight_layout()
# plt.savefig("/home/jolien/Notebooks/data_preprocessing/figures/PAGA_all_groupings.png")

In [None]:
# Plot force-directed graph with PAGA graph as initial cluster position - colored by majority fate of the lineage of the cell
# Disabled: uncomment to plot (and save) the force-directed graph colored by majority fate
# sc.pl.draw_graph(adata, color=["majority_fate"], save="_PAGA_majority_fate.png")

In [None]:
# UMAP plot
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

# One UMAP panel per grouping; only the Louvain panel puts its legend on the data
umap_panels = (
    dict(color='louvain', legend_loc='on data'),
    dict(color='time_point', palette=sample_type_palette_time),
    dict(color='sample_type', palette=sample_type_palette),
)
for panel_ax, panel_kwargs in zip(axes, umap_panels):
    sc.pl.umap(adata, ax=panel_ax, show=False, **panel_kwargs)

# Tidy the layout; saving the combined plot is currently disabled
plt.tight_layout()
# plt.savefig("/home/jolien/Notebooks/data_preprocessing/figures/UMAP_all_groupings.png")

### 10. Save AnnData object

In [None]:
# Export is disabled: uncomment to write the preprocessed AnnData object to disk
# adata.write('/home/jolien/Notebooks/data/preprocessed_data_v2.h5ad')