# Subset Notebook

Work in progress: this notebook analyzes the clusters derived in the cluster notebook so that they can be used to create RAG vectors.

In [None]:
import warnings

# import numba
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning

# Silence every DeprecationWarning for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Also silence numba's own deprecation warning categories imported above.
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
warnings.simplefilter("ignore", category=NumbaPendingDeprecationWarning)

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import anndata as ad

# import os
from scipy.sparse import csr_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# import celltypist
# from celltypist import models
# import scarches as sca

# import urllib.request

# Suppress pandas PerformanceWarning messages (uses `warnings`, imported in an earlier cell).
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)

# Configure scanpy's figure parameters with a (5, 5) figure size.
sc.set_figure_params(figsize=(5, 5))  # type: ignore

In [None]:
# Load the AnnData object stored at ../data/subset.h5ad (path is relative to the working directory).
adata = sc.read_h5ad("../data/subset.h5ad")
# Bare expression: shows the object's repr as the cell output.
adata

In [None]:
def get_highly_variable_genes(adata: ad.AnnData) -> ad.AnnData:
    """Restrict an AnnData object to the genes flagged as highly variable.

    Args:
        adata (ad.AnnData): Dataset whose ``var`` table has a ``highly_variable`` column
            (used as a row mask on ``var``).

    Returns:
        ad.AnnData: ``adata`` indexed down to the flagged genes; all cells are kept.
    """
    is_variable = adata.var.highly_variable
    selected_genes = adata.var[is_variable].index
    return adata[:, selected_genes]  # type: ignore

# Highly variable gene subset of the loaded data; input to the clustering helpers in the following cells.
hvar = get_highly_variable_genes(adata)

In [None]:
def get_cluster_names(adata: ad.AnnData, criterion="leiden_2") -> list[str]:
    """List the distinct cluster labels of ``adata`` in ascending numeric order.

    Args:
        adata (ad.AnnData): Dataset whose ``obs`` table holds the cluster assignment.
        criterion (str, optional): Name of the ``obs`` column with the labels. Defaults to "leiden_2".

    Returns:
        list[str]: The unique labels as strings, sorted by their integer value.
            Every label must be convertible with ``int()``.
    """
    numeric_labels = sorted(int(label) for label in adata.obs[criterion].unique())
    return [str(label) for label in numeric_labels]
# Show the numerically sorted cluster labels found in the highly variable subset.
print(get_cluster_names(hvar))

In [None]:
from typing import Any

def partition_clusters(adata: ad.AnnData, criterion="leiden_2") -> dict[str, ad.AnnData]:
    """Split ``adata`` into one independent AnnData object per cluster.

    Args:
        adata (ad.AnnData): Dataset to split.
        criterion (str, optional): Name of the ``obs`` column with the cluster labels. Defaults to "leiden_2".

    Returns:
        dict[str, ad.AnnData]: Maps each cluster name (in the order produced by
            ``get_cluster_names``) to a copy of the cells carrying that label.
    """
    labels = adata.obs[criterion]
    # .copy() detaches each subset from `adata` so the entries are not views
    return {
        name: adata[labels == name].copy()  # type: ignore
        for name in get_cluster_names(adata, criterion)
    }

# Partition the highly variable subset into one AnnData per cluster.
cluster_table = partition_clusters(hvar)
print(f"Length of cluster table: {len(cluster_table)}")
# Every entry must be an independent AnnData copy rather than a view ("slice") of `hvar`.
# isinstance alone cannot tell the two apart because a view is an AnnData instance too,
# so `is_view` is checked as well. all() also passes on an empty table instead of raising IndexError.
assert all(
    isinstance(cluster_adata, ad.AnnData) and not cluster_adata.is_view
    for cluster_adata in cluster_table.values()
)  # ensure that we are dealing with copies, not slices

In [None]:
def calculate_highest_frequency_genes(adata: ad.AnnData, number_of_genes:int = 20, expression_threshold=0.0, verbose=False) -> list[str]:
    """Rank genes by the number of cells in which they are expressed.

    A gene counts as expressed in a cell when its value in ``adata.X`` is strictly
    greater than ``expression_threshold``.

    Args:
        adata (ad.AnnData): Cells to scan, one row of ``adata.X`` per cell.
        number_of_genes (int, optional): How many top-ranked genes to return. Defaults to 20.
        expression_threshold (float, optional): Minimum (exclusive) expression level. Defaults to 0.0.
        verbose (bool, optional): Print the dataset size and the count and percentage of
            cells for each returned gene. Defaults to False.

    Returns:
        list[str]: Gene names ordered from most to least frequently expressed. Genes with
            equal counts keep the order in which they were first encountered.
    """
    n_cells, n_genes = adata.shape
    if verbose:
        print(f"{n_cells} cells, {n_genes} genes")

    # gene name -> number of cells expressing it (insertion order = first-seen order)
    occurrences: dict[str, int] = {}
    for row in range(n_cells):
        expressed = adata.X[row] > expression_threshold  # type: ignore
        for name in adata.var.index[expressed]:
            occurrences[name] = occurrences.get(name, 0) + 1

    # sorted() is stable, so ties stay in first-seen order
    ranked = sorted(occurrences.items(), key=lambda item: item[1], reverse=True)
    top = ranked[0:number_of_genes]
    if verbose:
        for name, hits in top:
            print(f"{name}:{hits} ({hits/n_cells*100:4.1f}%)")
    return [name for name, _ in top]
            

In [None]:
def calculate_gene_signature_per_cluster(cluster_table: dict[str, ad.AnnData],
                                         genes_per_cluster=25,
                                         repeat_limit=5,
                                         expression_threshold=0.0
                                        )-> dict[str,list[str]]:
    """Build a per-cluster gene signature from each cluster's most frequently expressed genes.

    Args:
        cluster_table (dict[str, ad.AnnData]): Cluster data keyed by cluster name.
        genes_per_cluster (int, optional): Number of top genes taken from each cluster. Defaults to 25.
        repeat_limit (int, optional): Genes that are a top gene in this many clusters or more
            are treated as present "everywhere" and dropped. Defaults to 5.
        expression_threshold (float, optional): Minimum (exclusive) expression level for a
            cell to count towards a gene. Defaults to 0.0.

    Returns:
        dict[str, list[str]]: For every cluster name, the retained genes (possibly an empty list).
    """
    # gene -> clusters that list it among their top genes
    clusters_by_gene: dict[str, list[str]] = {}
    for name, cdata in cluster_table.items():
        top_genes = calculate_highest_frequency_genes(
            adata=cdata,
            number_of_genes=genes_per_cluster,
            expression_threshold=expression_threshold,
        )
        for gene in top_genes:
            clusters_by_gene.setdefault(gene, []).append(name)

    # invert the mapping, skipping the genes shared by too many clusters
    signatures: dict[str, list[str]] = {name: [] for name in cluster_table}
    for gene, members in clusters_by_gene.items():
        if len(members) >= repeat_limit:
            continue
        for name in members:
            signatures[name].append(gene)
    return signatures

# Per-cluster signatures with the default genes_per_cluster and repeat_limit,
# counting a gene as expressed only when its value exceeds 1.5.
cluster_dict = calculate_gene_signature_per_cluster(cluster_table, expression_threshold=1.5)
# Print one "cluster:genes" line per cluster.
for cluster in cluster_dict:
    print(f"{cluster}:{cluster_dict[cluster]}")


In [None]:
def find_redundant_genes(adata: ad.AnnData, genes_per_cluster=25, repeat_limit=5, expression_threshold=0.0, criterion="leiden_2") -> set[str]:
    """Find genes that rank among the most frequently expressed genes of many clusters.

    The clusters are derived from ``adata`` itself (via ``partition_clusters``) instead of
    from a notebook-level variable, so the result depends only on the arguments.

    Args:
        adata (ad.AnnData): Cells to analyse; ``adata.obs[criterion]`` must hold the cluster labels.
        genes_per_cluster (int, optional): Number of top genes taken from each cluster. Defaults to 25.
        repeat_limit (int, optional): A gene is redundant when it is a top gene in at least
            this many clusters. Defaults to 5.
        expression_threshold (float, optional): Minimum (exclusive) expression level for a
            cell to count towards a gene. Defaults to 0.0.
        criterion (str, optional): Name of the ``obs`` column with the cluster labels. Defaults to "leiden_2".

    Returns:
        set[str]: Names of the genes shared by ``repeat_limit`` or more clusters.
    """
    # gene -> clusters that list it among their top genes
    clusters_by_gene: dict[str, list[str]] = {}
    for cluster, cluster_adata in partition_clusters(adata, criterion).items():
        gene_list = calculate_highest_frequency_genes(
            adata=cluster_adata,
            number_of_genes=genes_per_cluster,
            expression_threshold=expression_threshold,
        )
        # record which clusters express each gene
        for gene in gene_list:
            clusters_by_gene.setdefault(gene, []).append(cluster)

    # keep only the genes present in at least `repeat_limit` clusters
    # NOTE: mitochondrial genes are not filtered out here
    return {
        gene
        for gene, clusters in clusters_by_gene.items()
        if len(clusters) >= repeat_limit
    }
        

# The redundant-gene analysis runs on the highly variable subset that `cluster_table`
# was built from, so pass `hvar` rather than the full `adata`.
find_redundant_genes(hvar, expression_threshold=1.5)

In [None]:
#
# need to extract data for a single cell
#
# Exploratory cell: pull the name and the expression row of the cell at position 4.
hv = get_highly_variable_genes(adata)  # same selection as `hvar`, bound to a second name
cell_name = hv.obs.index[4]
print(cell_name)
# Work on a copy of the row rather than on the AnnData's own matrix.
gene_data = hv.X[4].copy() # type: ignore
# Number of entries in the row (assumes X is a dense array, one value per gene; TODO confirm).
print(len(gene_data))



In [None]:
# Prototype of the per-cell signature logic, applied to the single cell extracted in the previous cell.
expression_threshold = 1.5
# Boolean mask of the genes whose expression in this cell exceeds the threshold.
b= gene_data > expression_threshold
expression = gene_data[b]
names = hv.var.index[b]
assert expression.shape == names.shape
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
# Map gene name -> expression, ordered from highest to lowest expression.
genes = dict(sorted(dict(zip(names,expression)).items(), key=lambda x:x[1], reverse=True))
# Drop the redundant genes and every gene whose name starts with "MT-" (case-sensitive match).
genes = {k:v for k,v in genes.items() if k not in redundant and not k.startswith("MT-")}
print(len(genes))


In [None]:
# Display the filtered gene -> expression mapping of the example cell.
genes

In [None]:
def get_gene_signature(adata: ad.AnnData,
                   feature_index: int,
                   expression_threshold=0.0,
                   redundant_genes: "set[str] | None" = None,
                   verbose: bool = False) -> dict[str,float]:
    """Calculate the list of genes for the given cell based on the feature_index.  Only non-mitochondrial
    genes with an expression level greater than the provided expression threshold are reported.

    Args:
        adata (ad.AnnData): The AnnData containing the cells.  Assume that only highly variable genes have been provided.
        feature_index (int): The index of the cell to be measured.
        expression_threshold (float, optional): The minimum (exclusive) expression level. Defaults to 0.0, which returns all genes with positive expression.
        redundant_genes (set[str] | None, optional): Genes to be filtered out from the final result. Defaults to None, meaning no gene is excluded.
        verbose (bool, optional): Print intermediate data to standard output. Defaults to False.

    Returns:
        dict[str, float]: Dictionary with gene names as keys and expression as values, ordered from highest to lowest expression.

    Raises:
        ValueError: If feature_index lies outside 0..(number of cells - 1).
    """
    # A None sentinel replaces the former mutable `set()` default, which is created once at
    # definition time and shared between all calls.
    if redundant_genes is None:
        redundant_genes = set()

    num_cells, _ = adata.shape
    if feature_index < 0 or feature_index > num_cells-1:
        raise ValueError(f"Feature index ({feature_index}) outside the range of cells (0..{num_cells-1}) in the current dataset.")
    gene_data = adata.X[feature_index].copy() # type: ignore
    if verbose:
        print(f"Started with {len(gene_data)} genes")

    # calculate the mask to find gene subset
    above_threshold = gene_data > expression_threshold
    expression = gene_data[above_threshold]
    names = adata.var.index[above_threshold]
    assert expression.shape == names.shape
    if verbose:
        print(f"Found {len(names)} genes exceeding expression threshold.")

    # sort the genes based on expression (stable sort: ties keep their gene order)
    ranked = sorted(dict(zip(names, expression)).items(), key=lambda item: item[1], reverse=True)
    # remove redundant and mitochondrial genes
    # NOTE(review): the "MT-" prefix match is case-sensitive; lowercase "mt-" names are not removed.
    genes = {
        name: value
        for name, value in ranked
        if name not in redundant_genes and not name.startswith("MT-")
    }
    if verbose:
        print(f"Found {len(genes)} genes after filtering.")
    return genes

In [None]:
# Try get_gene_signature on the cell at position 2, excluding the redundant genes.
expression_threshold = 1.5
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
sig = get_gene_signature(adata=hv, 
                     feature_index=2, 
                     expression_threshold=expression_threshold, 
                     redundant_genes=redundant,
                     verbose=False)
print(len(sig))
# Display the gene -> expression mapping for that cell.
sig

In [None]:
# Prototype of the per-cluster signature loop, run inline over every cluster in cluster_table.
expression_threshold = 1.5
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
dataframes :list[pd.DataFrame] = []
total_cells=0
for cluster_name in cluster_table:
    # one frame per cluster, indexed by cell name, with empty 'cluster' and 'signature' columns
    cluster_signatures = pd.DataFrame(columns = ['cluster','signature'], index=cluster_table[cluster_name].obs.index)
    cluster_adata = cluster_table[cluster_name]
    n_cells, n_genes = cluster_adata.shape
    print(n_cells)
    total_cells+=n_cells
    # fill the frame row by row; the signature is the cell's gene names joined by spaces
    for cell_no in range(n_cells):
        cluster_signatures.iloc[cell_no,0] = cluster_name
        cluster_signatures.iloc[cell_no,1] = " ".join(list(get_gene_signature(adata=cluster_adata, 
                                                        feature_index=cell_no,
                                                        expression_threshold=expression_threshold,
                                                        redundant_genes=redundant
                                                        ).keys()))
    dataframes.append(cluster_signatures)

# stack the per-cluster frames into a single frame covering every cell
sigs = pd.concat(dataframes, axis=0)
# the row count of sigs should equal the total number of cells printed next
print(sigs.shape)
print(total_cells)

        

In [None]:
def process_clusters(cluster_table:dict[str,ad.AnnData], redundant_genes:set[str], expression_threshold=0.0, verbose=False) -> pd.DataFrame:
    """Given a dictionary of clusters and a set of redundant genes, calculates the gene_signature on a cell by cell basis.

    Args:
        cluster_table (dict[str,ad.AnnData]): Holds the cluster data with cluster names as keys.
        redundant_genes (set[str]): A set of separately calculated genes to exclude from signatures.
        expression_threshold (float, optional): Only count genes with expression levels greater than this number. Defaults to 0.0.
        verbose (bool, optional): Print per-cluster progress and a final summary. Defaults to False.

    Returns:
        pd.DataFrame: A dataframe with cells as the index and columns for cluster name and signature (as a space separated string).
            An empty cluster_table yields an empty dataframe with the same two columns.
    """
    columns = ['cluster', 'signature']
    dataframes :list[pd.DataFrame] = []
    total_cells = 0
    for cluster_name, cluster_adata in cluster_table.items():
        n_cells, _ = cluster_adata.shape
        if verbose:
            print(f"cluster {cluster_name} has {n_cells} cells.")
        total_cells += n_cells

        # Use the `redundant_genes` argument (not a notebook-level variable) so the result
        # depends only on what the caller passed in.
        signatures = [
            " ".join(get_gene_signature(adata=cluster_adata,
                                        feature_index=cell_no,
                                        expression_threshold=expression_threshold,
                                        redundant_genes=redundant_genes))
            for cell_no in range(n_cells)
        ]
        # Build the frame once from complete columns instead of assigning cell by cell.
        # dtype=object keeps plain Python strings in both columns.
        cluster_signatures = pd.DataFrame(
            {'cluster': [cluster_name] * n_cells, 'signature': signatures},
            index=cluster_adata.obs.index,
            columns=columns,
            dtype=object,
        )
        if verbose:
            print(cluster_signatures.head())
        dataframes.append(cluster_signatures)

    # pd.concat raises on an empty list, so handle an empty cluster table explicitly
    if not dataframes:
        return pd.DataFrame(columns=columns, dtype=object)

    sigs = pd.concat(dataframes, axis=0)
    assert sigs.shape[0] == total_cells # sanity check to ensure that all cells are being processed
    if verbose:
        print(f"Processed {total_cells} to produce a dataframe with dimensions {sigs.shape}.")
    return sigs

# Use the same threshold variable for both steps so the redundant-gene set and the
# per-cell signatures cannot drift apart (expression_threshold is set in an earlier cell).
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
sigs_pd = process_clusters(cluster_table=cluster_table, redundant_genes=redundant, expression_threshold=expression_threshold)


In [None]:
# Inspect the rows assigned to cluster '0'; count() reports the non-null entries per column.
b = sigs_pd[sigs_pd.cluster=='0']
b.count()

In [None]:
# write signatures to disk
# Output is a CSV at ../data/sigs.csv; the frame's index is written as the first column (pandas default).
sigs_pd.to_csv("../data/sigs.csv")