# Subset Notebook

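Working towards analyzing the clusters derived in the cluster notebook so that they can be used to create RAG vectors. For each cluster the notebook derives a gene signature (genes frequently expressed in that cluster but not among the top genes of most other clusters), and for individual cells it lists the expressed genes after removing mitochondrial and cluster-redundant genes.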

In [67]:
import warnings

# import numba
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning

# Silence general deprecation warnings for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Also silence numba's own deprecation / pending-deprecation warning categories.
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
warnings.simplefilter("ignore", category=NumbaPendingDeprecationWarning)

In [68]:
import scanpy as sc
import pandas as pd
import numpy as np
import anndata as ad

# import os
from scipy.sparse import csr_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# import celltypist
# from celltypist import models
# import scarches as sca

# import urllib.request

# Hide pandas PerformanceWarning messages.
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)

# Default figure size for scanpy plots.
sc.set_figure_params(figsize=(5, 5))  # type: ignore

In [24]:
# Clustered subset produced by the cluster notebook (path is relative to the working directory).
SUBSET_H5AD_PATH = "data/subset.h5ad"
adata = sc.read_h5ad(SUBSET_H5AD_PATH)
adata

AnnData object with n_obs × n_vars = 9370 × 31208
    obs: 'n_genes_by_counts', 'total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'DF_score', 'batch', 'size_factors', 'leiden_2'
    var: 'gene_ids', 'feature_types', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable'

In [25]:
def get_highly_variable_genes(adata: ad.AnnData) -> ad.AnnData:
    """Return ``adata`` restricted to the genes flagged in ``adata.var["highly_variable"]``.

    Args:
        adata: AnnData whose ``var`` table has a boolean ``highly_variable`` column.

    Returns:
        ``adata`` indexed by the gene mask (an AnnData view, not a copy; call
        ``.copy()`` on the result if an independent object is needed).

    Raises:
        KeyError: if ``adata.var`` has no ``highly_variable`` column.
    """
    # Select positionally with a boolean mask instead of round-tripping through gene
    # names: name-based selection is ambiguous when var_names are not unique.
    hv_mask = adata.var["highly_variable"].to_numpy(dtype=bool)
    return adata[:, hv_mask]  # type: ignore

# adata restricted to its highly variable genes; the cluster analysis below works on this subset.
hvar = get_highly_variable_genes(adata)

In [26]:
def get_cluster_names(adata: ad.AnnData, criterion="leiden_2") -> list[str]:
    """Return the distinct cluster labels in ``adata.obs[criterion]`` as sorted strings.

    When every label parses as an integer (e.g. "0", "1", ..., "18"), the labels are
    ordered numerically, so "10" follows "9" instead of "1". If any label is not an
    integer, the labels are ordered as plain strings instead of raising ValueError.

    Each label is returned in its own string form (``str(label)``), so e.g. "01" stays
    "01" and still equals the value stored in the column.

    Args:
        adata: AnnData whose ``obs`` contains the ``criterion`` column.
        criterion: name of the ``obs`` column holding the cluster labels.

    Returns:
        Sorted list of the distinct labels, as strings.
    """
    labels = [str(label) for label in adata.obs[criterion].unique()]
    try:
        return sorted(labels, key=int)
    except ValueError:
        # Non-integer labels (e.g. cell-type names): fall back to string order.
        return sorted(labels)
# Distinct labels of the default "leiden_2" clustering, in ascending order.
print(get_cluster_names(hvar))

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']


In [27]:
from typing import Any

def partition_clusters(adata: ad.AnnData, criterion="leiden_2") -> dict[str, ad.AnnData]:
    """Split ``adata`` into one independent AnnData copy per cluster.

    Args:
        adata: AnnData whose ``obs[criterion]`` column holds the cluster labels.
        criterion: name of the ``obs`` column to partition on.

    Returns:
        Dict mapping each name returned by ``get_cluster_names`` (in that order) to a
        ``.copy()`` of the cells carrying that label, so later changes to a cluster do
        not affect ``adata``.
    """
    # get_cluster_names yields strings; compare against the string form of the labels
    # so integer- or otherwise non-string-typed label columns still match instead of
    # silently producing empty clusters.
    labels = adata.obs[criterion].astype(str)
    cluster_table: dict[str, ad.AnnData] = {}
    for cluster in get_cluster_names(adata, criterion):
        in_cluster = (labels == cluster).to_numpy()
        cluster_table[cluster] = adata[in_cluster].copy()  # type: ignore
    return cluster_table

cluster_table = partition_clusters(hvar)
print(f"Length of cluster table: {len(cluster_table)}")
# Every entry must be an independent copy. isinstance() alone cannot tell, because an
# AnnData view is also an ad.AnnData instance; is_view distinguishes the two.
assert all(
    isinstance(cluster_adata, ad.AnnData) and not cluster_adata.is_view
    for cluster_adata in cluster_table.values()
), "expected independent AnnData copies, not views"

Length of cluster table: 19


In [49]:
def calculate_highest_frequency_genes(adata: ad.AnnData, number_of_genes:int = 20, expression_threshold=0.0, verbose=False) -> list[str]:
    """Return the genes expressed above ``expression_threshold`` in the most cells.

    A gene counts as expressed in a cell when that cell's entry in ``adata.X`` is
    strictly greater than ``expression_threshold``. Genes are ranked by how many cells
    express them (most first); genes expressed in no cell are left out. Ties are
    broken by the first cell (row) in which the gene is expressed, then by the gene's
    column position, which is the order a cell-by-cell scan would produce.

    The counts come from one vectorized comparison over the whole matrix rather than a
    Python loop over cells, and ``adata.X`` may be a dense array or a SciPy sparse matrix.

    Args:
        adata: AnnData with cells as rows and genes as columns of ``adata.X``.
        number_of_genes: maximum number of gene names to return.
        expression_threshold: strict lower bound on expression.
        verbose: if True, print the matrix size and, for each selected gene, the
            number and percentage of cells expressing it.

    Returns:
        Up to ``number_of_genes`` names from ``adata.var.index``, most frequently
        expressed first.
    """
    from scipy import sparse  # only csr_matrix is imported at the top of the notebook

    cell_count, gene_count = adata.shape
    if verbose:
        print(f"{cell_count} cells, {gene_count} genes")
    if cell_count == 0:
        # No cells means no expressed genes (argmax below is undefined on empty input).
        return []

    matrix = adata.X
    if sparse.issparse(matrix):
        expressed = (matrix > expression_threshold).astype(np.int8)
        counts = np.asarray(expressed.sum(axis=0)).ravel()
        # argmax returns the first row holding the column maximum (1 when expressed).
        first_cell = np.asarray(expressed.argmax(axis=0)).ravel()
    else:
        expressed = np.asarray(matrix) > expression_threshold
        counts = expressed.sum(axis=0)
        first_cell = expressed.argmax(axis=0)

    # np.lexsort uses the LAST key as the primary key: count (descending), then the
    # first expressing cell, then the column position.
    order = np.lexsort((np.arange(gene_count), first_cell, -counts))
    order = order[counts[order] > 0]
    top = order[:number_of_genes]
    gene_list = adata.var.index[top].tolist()

    if verbose:
        for gene, count in zip(gene_list, counts[top]):
            print(f"{gene}:{count} ({count/cell_count*100:4.1f}%)")
    return gene_list
            

In [51]:
def calculate_gene_signature_per_cluster(cluster_table: dict[str, ad.AnnData],
                                         genes_per_cluster=25,
                                         repeat_limit=5,
                                         expression_threshold=0.0
                                        ) -> dict[str, list[str]]:
    """Build a gene signature for every cluster in ``cluster_table``.

    For each cluster, ``calculate_highest_frequency_genes`` picks the
    ``genes_per_cluster`` genes expressed above ``expression_threshold`` in the most
    cells. A gene found in the top lists of ``repeat_limit`` or more clusters is
    discarded as non-distinguishing. Every other top gene is assigned to each
    cluster whose top list contains it.

    Args:
        cluster_table: cluster name -> AnnData holding that cluster's cells.
        genes_per_cluster: number of top genes taken from each cluster.
        repeat_limit: a gene in this many clusters' top lists (or more) is dropped.
        expression_threshold: strict lower bound for a gene to count as expressed.

    Returns:
        cluster name -> signature genes. Every key of ``cluster_table`` is present,
        possibly with an empty list. Genes appear in the order they were first seen
        while walking the clusters.
    """
    # gene -> clusters (in cluster_table order) whose top list includes it
    owners_of: dict[str, list[str]] = {}
    for name, members in cluster_table.items():
        top_genes = calculate_highest_frequency_genes(
            adata=members,
            number_of_genes=genes_per_cluster,
            expression_threshold=expression_threshold,
        )
        for gene in top_genes:
            owners_of.setdefault(gene, []).append(name)

    signatures: dict[str, list[str]] = {name: [] for name in cluster_table}
    for gene, owners in owners_of.items():
        if len(owners) >= repeat_limit:
            continue  # present "everywhere": does not help tell clusters apart
        for name in owners:
            signatures[name].append(gene)
    return signatures

cluster_dict = calculate_gene_signature_per_cluster(cluster_table, expression_threshold=1.5)
# One line per cluster: "<cluster>:<list of signature genes>"
for cluster, signature_genes in cluster_dict.items():
    print(f"{cluster}:{signature_genes}")


0:['GNLY', 'S100A4', 'TXNIP', 'PRKCH', 'DDX5']
1:['TXNIP', 'PRKCH', 'DDX5', 'SYNE2', 'SYNE1', 'FYN', 'RIPOR2', 'PARP8', 'THEMIS']
2:['PARP8', 'CBLB', 'RORA', 'IL7R', 'RPL41', 'PLCB1', 'SKAP1', 'RPL11']
3:['GNLY', 'S100A4', 'TXNIP', 'DDX5', 'RPL41', 'HLA-A']
4:['RIPOR2', 'SKAP1', 'FOXP1', 'PDE3B', 'INPP4B', 'SMCHD1', 'MAML2', 'ATM']
5:['FOXP1', 'SMCHD1', 'BACH2', 'BANK1', 'RALGPS2', 'FCHSD2', 'ADK', 'ZCCHC7', 'CAMK2D', 'LYN', 'ARHGAP24']
6:['RIPOR2', 'FOXP1', 'SMCHD1', 'BACH2', 'BANK1', 'RALGPS2', 'FCHSD2', 'ADK', 'ZCCHC7', 'CAMK2D', 'LYN', 'MEF2C', 'PRKCB']
7:['PRKCH', 'FYN', 'PARP8', 'CBLB', 'SKAP1', 'AOAH', 'CEMIP2', 'ATP8A1', 'CNOT6L']
8:['FOXP1', 'BACH2', 'ZCCHC7', 'CAMK2D', 'LYN', 'PCDH9', 'PDE4D', 'SSBP2', 'TCF4', 'ARID1B', 'EBF1', 'ACSM3', 'TMEM131L']
9:['AOAH', 'NAMPT', 'VCAN', 'NEAT1', 'DPYD', 'ANXA1', 'ARHGAP26', 'PLXDC2', 'FOS', 'QKI', 'SIPA1L1', 'ATP2B1', 'LRMDA', 'LYST', 'HIF1A', 'JMJD1C', 'MED13L', 'VIM', 'STK17B']
10:['RPL41', 'SLC25A21', 'NFIA', 'TFRC', 'RPLP1', 'CD36',

In [79]:
def find_redundant_genes(adata: ad.AnnData, genes_per_cluster=25, repeat_limit=5, expression_threshold=0.0, criterion="leiden_2") -> set[str]:
    """Return genes that are among the most frequently expressed genes of many clusters.

    ``adata`` is partitioned by ``adata.obs[criterion]`` with ``partition_clusters``.
    For each cluster, ``calculate_highest_frequency_genes`` selects the
    ``genes_per_cluster`` genes expressed above ``expression_threshold`` in the most
    cells. A gene is redundant when it appears in the top lists of at least
    ``repeat_limit`` clusters, so it does little to distinguish one cluster from another.

    Args:
        adata: AnnData to analyse; ``adata.obs`` must contain the ``criterion`` column.
        genes_per_cluster: number of top genes taken from each cluster.
        repeat_limit: minimum number of clusters whose top list must contain a gene.
        expression_threshold: strict lower bound for a gene to count as expressed.
        criterion: ``obs`` column holding the cluster labels.

    Returns:
        The set of redundant gene names (mitochondrial genes are not filtered here).
    """
    # Partition the AnnData that was passed in, so the result depends only on the
    # arguments rather than on a module-level cluster table.
    clusters = partition_clusters(adata, criterion)

    # gene name -> number of clusters whose top list contains it
    cluster_hits: dict[str, int] = {}
    for cluster_adata in clusters.values():
        top_genes = calculate_highest_frequency_genes(
            adata=cluster_adata,
            number_of_genes=genes_per_cluster,
            expression_threshold=expression_threshold,
        )
        for gene in top_genes:
            cluster_hits[gene] = cluster_hits.get(gene, 0) + 1

    return {gene for gene, hits in cluster_hits.items() if hits >= repeat_limit}
        

# Use the highly-variable-gene subset: the cluster signatures above were computed from
# hvar, and the later single-cell analysis also works on highly variable genes.
find_redundant_genes(hvar, expression_threshold=1.5)

{'ACTB',
 'AFF3',
 'ANKRD44',
 'ARHGAP15',
 'B2M',
 'CD74',
 'EEF1A1',
 'HBB',
 'HLA-B',
 'MALAT1',
 'MBNL1',
 'MT-ATP6',
 'MT-CO1',
 'MT-CO2',
 'MT-CO3',
 'MT-CYB',
 'MT-ND4',
 'MT-ND5',
 'PLCG2',
 'PTPRC',
 'RABGAP1L',
 'UTRN',
 'ZBTB20',
 'ZEB2'}

In [64]:
#
# need to extract data for a single cell
#
# Pull the expression vector of one example cell out of the highly-variable-gene subset.
hv = get_highly_variable_genes(adata)
CELL_POSITION = 4  # row position of the example cell
cell_name = hv.obs.index[CELL_POSITION]
print(cell_name)
gene_data = hv.X[CELL_POSITION].copy() # type: ignore
print(len(gene_data))



TGACCAAGTAGACAAA
4000


In [65]:
# Prototype of the per-cell signature (later wrapped as get_gene_signature) for the cell above.
expression_threshold = 1.5
b= gene_data > expression_threshold  # mask of genes above the threshold in this cell
expression = gene_data[b]
names = hv.var.index[b]
assert expression.shape == names.shape
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
# Sort genes by expression (highest first), then drop redundant and "MT-" (mitochondrial) genes.
genes = dict(sorted(dict(zip(names,expression)).items(), key=lambda x:x[1], reverse=True))
genes = {k:v for k,v in genes.items() if k not in redundant and not k.startswith("MT-")}
print(len(genes))


15


In [66]:
# Display the filtered gene -> expression mapping computed in the previous cell.
genes

{'RPL11': 1.8190973713412917,
 'CEP350': 1.8190973713412917,
 'GNLY': 1.8190973713412917,
 'PTPN4': 1.8190973713412917,
 'SMARCA5': 1.8190973713412917,
 'KIAA0825': 1.8190973713412917,
 'ORC5': 1.8190973713412917,
 'SARAF': 1.8190973713412917,
 'PDCD4': 1.8190973713412917,
 'ABLIM1': 1.8190973713412917,
 'FNBP4': 1.8190973713412917,
 'SLC38A1': 1.8190973713412917,
 'ZC3H13': 1.8190973713412917,
 'CTDSPL2': 1.8190973713412917,
 'NF1': 1.8190973713412917}

In [89]:
def get_gene_signature(adata: ad.AnnData,
                   feature_index: int,
                   expression_threshold=0.0,
                   redundant_genes: "set[str] | None" = None,
                   verbose: bool = False) -> dict[str,float]:
    """Return the expressed genes of one cell, sorted by expression (highest first).

    Only genes whose expression is strictly greater than ``expression_threshold`` are
    reported. Mitochondrial genes (names starting with "MT-") and genes listed in
    ``redundant_genes`` are removed.

    Args:
        adata (ad.AnnData): The AnnData containing the cells. Assumed to hold only highly variable genes.
            ``adata.X`` may be a dense array or a SciPy sparse matrix.
        feature_index (int): Row position of the cell to report.
        expression_threshold (float, optional): Exclusive minimum expression level. Defaults to 0.0,
            which returns every gene with positive expression.
        redundant_genes (set[str] | None, optional): Genes to drop from the result. Defaults to None
            (nothing dropped).
        verbose (bool, optional): Print intermediate counts to standard output. Defaults to False.

    Raises:
        ValueError: If ``feature_index`` is outside ``0..n_cells-1``.

    Returns:
        dict[str, float]: Gene name -> expression, in descending order of expression;
            genes with equal expression keep their ``adata.var`` order.
    """
    from scipy import sparse  # only csr_matrix is imported at the top of the notebook

    num_cells, _ = adata.shape
    if feature_index < 0 or feature_index > num_cells-1:
        raise ValueError(f"Feature index ({feature_index}) outside the range of cells (0..{num_cells-1}) in the current dataset.")
    # None replaces the former mutable set() default; an empty set excludes nothing.
    excluded = set() if redundant_genes is None else redundant_genes

    row = adata.X[feature_index]  # type: ignore
    # A sparse matrix row is a 1 x n_genes matrix: densify and flatten it so that len()
    # and boolean masking behave as they do for a dense 1-D row.
    if sparse.issparse(row):
        row = row.toarray()
    gene_data = np.asarray(row).ravel()
    if verbose:
        print(f"Started with {len(gene_data)} genes")

    # mask of genes expressed above the threshold in this cell
    above = gene_data > expression_threshold
    expression = gene_data[above]
    names = adata.var.index[above]
    assert expression.shape == names.shape
    if verbose:
        print(f"Found {len(names)} genes exceeding expression threshold.")

    # sort the genes by expression, highest first (stable for ties)
    genes = dict(sorted(dict(zip(names, expression)).items(), key=lambda x: x[1], reverse=True))
    # remove redundant and mitochondrial genes
    genes = {k: v for k, v in genes.items() if k not in excluded and not k.startswith("MT-")}
    if verbose:
        print(f"Found {len(genes)} genes after filtering.")
    return genes

In [96]:
# Signature of the cell at row position 2 of the highly-variable-gene subset, excluding
# genes that rank among the top genes of many clusters.
expression_threshold = 1.5
redundant = find_redundant_genes(hv, expression_threshold=expression_threshold)
sig = get_gene_signature(adata=hv, 
                     feature_index=2, 
                     expression_threshold=expression_threshold, 
                     redundant_genes=redundant,
                     verbose=False)
print(len(sig))
sig

44


{'BANK1': 2.4271888792164433,
 'ARID1B': 2.4271888792164433,
 'PLEKHA2': 2.3061907112814617,
 'LINC01619': 2.1685103296015886,
 'BBX': 2.0088007503141467,
 'LIX1-AS1': 2.0088007503141467,
 'BACH2': 2.0088007503141467,
 'JAZF1': 2.0088007503141467,
 'JMJD1C': 2.0088007503141467,
 'CHD9': 2.0088007503141467,
 'COBLL1': 1.8186444515907547,
 'ITPR1': 1.8186444515907547,
 'ZSWIM6': 1.8186444515907547,
 'FBXL17': 1.8186444515907547,
 'CDK14': 1.8186444515907547,
 'ADK': 1.8186444515907547,
 'FCHSD2': 1.8186444515907547,
 'PRH1': 1.8186444515907547,
 'ZFAND6': 1.8186444515907547,
 'PRKCB': 1.8186444515907547,
 'ANKRD12': 1.8186444515907547,
 'COP1': 1.5836324776071877,
 'RALGPS2': 1.5836324776071877,
 'USP34': 1.5836324776071877,
 'IWS1': 1.5836324776071877,
 'MGAT5': 1.5836324776071877,
 'SP100': 1.5836324776071877,
 'FOXP1': 1.5836324776071877,
 'ACAP2': 1.5836324776071877,
 'TAPT1': 1.5836324776071877,
 'TMEM131L': 1.5836324776071877,
 'EBF1': 1.5836324776071877,
 'GMDS-DT': 1.583632477607

(9370, 4000)