# Cluster Evaluation (ARI and NMI)
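We use ARI and NMI to evaluate the clustering performance of the learned embeddings. Note that the cell below reports the mutual-information score via scikit-learn's `adjusted_mutual_info_score`, i.e. the chance-adjusted variant.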

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, silhouette_score

In [None]:
dataname = 'Wang'  # Dataset name; every input path below is derived from it

# Load the embedding CSV (first column is the index) and wrap it in an AnnData object
adata = sc.AnnData(pd.read_csv('../output/'+dataname+'/'+dataname+'_embedding.csv',sep=',',index_col=0))

# Attach the reference cell-type annotation; it is the ground truth for ARI / AMI
label = pd.read_csv('../data/'+dataname+'_cell_anno.csv',sep=',',index_col=0)
adata.obs['cell_type'] = list(label[dataname.lower()+"@colData$cell_type1"])

# Build the neighborhood graph directly on the embedding (use_rep='X');
# both the Louvain clustering and the UMAP layout below rely on it
sc.pp.neighbors(adata,use_rep='X')

# Resolution grid for Louvain clustering: minn, minn + 0.1, ... up to (but not
# including) maxn. The grid is materialised once so that the resolution used
# for the final clustering is exactly one of the values that was scored.
maxn = 2
minn = 0
resolutions = [x / 10.0 for x in range(round(minn * 10), round(maxn * 10))]
list_value = []
for resolution in resolutions:
    sc.tl.louvain(adata, resolution=resolution, random_state=0)
    list_value.append(adjusted_rand_score(adata.obs['cell_type'], adata.obs['louvain']))

# adata.obs['louvain'] currently holds the clustering of the last resolution
# tried, so re-run with the best-scoring one (first maximum wins on ties).
best_resolution = resolutions[list_value.index(max(list_value))]
sc.tl.louvain(adata, resolution=best_resolution, random_state=0)
print(best_resolution)
sc.tl.umap(adata)

# Visualize UMAP with both true cell types and Louvain clusters
sc.pl.umap(
    adata,
    color=["cell_type","louvain"],
    wspace = 0.3,
    frameon=False,
    #save = "scE2TM_cluster_"+dataname+".pdf"
)
print("scE2TM  Adjusted_rand_score   "+str(adjusted_rand_score(adata.obs['cell_type'],adata.obs['louvain']))+"   Adjusted_mutual_info_score   "+str(adjusted_mutual_info_score(adata.obs['cell_type'],adata.obs['louvain'])))

# Interpretable Evaluation
We assess the interpretability of the single-cell embedding topic model using the proposed metrics: IP, TC, TD, TQ, ORA<sub>N</sub>, ORA<sub>U</sub>,
ORA<sub>Q</sub>, GSEA<sub>N</sub>, GSEA<sub>U</sub>, and GSEA<sub>Q</sub>.

In [None]:
import pandas as pd
import gseapy as gp
import numpy as np

# Read the gene-set collection from the GMT file with gseapy. It acts as the
# background corpus (bg_data) for TC: one row per gene set, one column per gene.
kegg = gp.read_gmt(path="../data/msigdb.v2024.1.Hs.symbols.gmt")

# Member lists of every gene set, in file order
cell_gene = list(kegg.values())

# Vocabulary of distinct genes across all sets, and gene -> column lookup
gene_set = list(set([gene for members in cell_gene for gene in members]))
dicts_gene_index = {gene: column for column, gene in enumerate(gene_set)}

# Binary membership matrix: bg_data[s, g] == 1 iff gene g belongs to gene set s
bg_data = np.zeros((len(kegg), len(gene_set)))
for row, members in enumerate(cell_gene):
    bg_data[row, [dicts_gene_index[gene] for gene in members]] = 1

In [None]:
# Input tables: the topic-gene matrix written by the model and the expression
# matrix it was trained on (only the latter's column names are used here).
tg_path = '../output/'+dataname+'/'+dataname+'_tg.csv'
exp_path = '../data/'+dataname+'_HIGHPRE_5000.csv'
data_tg = pd.read_csv(tg_path, sep=',', index_col=0)
data_exp = pd.read_csv(exp_path, sep=',', index_col=0)

# Name the topic-gene columns after the genes of the expression data ...
data_tg.columns = data_exp.columns
# ... and transpose so that genes index the rows and each column is one topic.
data_tg = data_tg.T

## TC

In [None]:
def compute_coherence(doc_gene, topic_gene, N, dicts_gene_tran):
    """
    Compute the Topic Coherence (TC) of a topic model.

    For every topic the N highest-weighted genes are selected. Each unordered
    pair of those genes that can be located in the background corpus
    contributes its normalized pointwise mutual information (NPMI), estimated
    from how often the two genes are present in the same row of ``doc_gene``.
    Pairs that never co-occur, or that involve a gene without an entry in
    ``dicts_gene_tran``, contribute 0. A topic's score is the sum of the pair
    scores divided by the total number of pairs N * (N - 1) / 2; the function
    returns the mean score over all topics.

    Parameters
    ----------
    doc_gene : 2-D array
        Background corpus, one row per document; an entry > 0 marks the gene
        of that column as present in the document.
    topic_gene : 2-D array of shape (topic, gene)
        Topic-gene weights.
    N : int
        Number of top genes taken per topic. Must be at least 2 (otherwise
        there are no pairs and a ZeroDivisionError is raised).
    dicts_gene_tran : dict
        Maps a column index of ``topic_gene`` to the matching column index of
        ``doc_gene``.

    Returns
    -------
    float
        Mean coherence over topics.
    """
    topic_size = np.shape(topic_gene)[0]
    doc_size = np.shape(doc_gene)[0]
    # Reciprocal of the number of unordered pairs among N genes
    pair_weight = 2 / (N * N - N)

    sum_coherence_score = 0.0
    for topic_idx in range(topic_size):
        # Indices of the N largest weights (order within the top N is irrelevant)
        top_gene_idx = np.argpartition(topic_gene[topic_idx, :], -N)[-N:]

        # Presence column of every top gene known to the background corpus,
        # extracted once per topic instead of once per pair
        presence = [
            doc_gene[:, dicts_gene_tran[gene_idx]] > 0
            for gene_idx in top_gene_idx
            if gene_idx in dicts_gene_tran
        ]
        counts = [np.count_nonzero(flags) for flags in presence]

        sum_score = 0.0
        for n, flag_n in enumerate(presence):
            if counts[n] == 0:
                continue
            p_n = counts[n] / doc_size
            for l in range(n + 1, len(presence)):
                joint = np.count_nonzero(flag_n & presence[l])
                if joint == 0:
                    # Genes never co-occur: the pair adds nothing
                    continue
                if joint == doc_size:
                    # Both genes are present in every document: -log(p_nl) is 0
                    # and the NPMI formula degenerates to 0/0. Use the
                    # conventional value for perfect co-occurrence instead of
                    # letting a NaN poison the whole score.
                    sum_score += 1.0
                    continue
                p_l = counts[l] / doc_size
                p_nl = joint / doc_size
                sum_score += np.log(p_nl / (p_l * p_n)) / -np.log(p_nl)
        sum_coherence_score += sum_score * pair_weight
    return sum_coherence_score / topic_size

# Index translation table: position of a gene in the topic-gene table ->
# column of the same gene in the background corpus matrix. Genes that do not
# occur in the background corpus simply get no entry.
dicts_gene_tran = {
    position: dicts_gene_index[gene]
    for position, gene in enumerate(data_tg.index)
    if gene in dicts_gene_index
}

# Coherence over the top-10 genes of every topic (topics as rows after .T)
TC = compute_coherence(bg_data, data_tg.T.values, 10, dicts_gene_tran)

print(f"===>TC_T{10}: {TC:.5f}")

## TD

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
def TD_eva(texts):
    """
    Calculate the Topic Distinctiveness (TD) metric.

    TD is the number of genes that occur exactly once across all topics,
    divided by K * T, where K is the number of topics and T the number of
    genes listed for the first topic. Higher TD means less overlap between
    topics.

    Gene names are obtained by splitting each string on whitespace only and
    are compared case-sensitively, so symbols containing punctuation (for
    example ``HLA-A``) or consisting of a single character are counted as
    written rather than being split apart or discarded by a text tokenizer.

    Parameters
    ----------
    texts : sequence of str
        One whitespace-separated string of top genes per topic.

    Returns
    -------
    float
        TD in [0, 1]; 0.0 when ``texts`` is empty or the first topic lists
        no genes.
    """
    from collections import Counter

    if not texts:
        return 0.0
    K = len(texts)
    T = len(texts[0].split())
    if T == 0:
        return 0.0
    # Total number of occurrences of every gene over all topics
    TF = Counter(gene for text in texts for gene in text.split())
    unique_genes = sum(1 for occurrences in TF.values() if occurrences == 1)
    return unique_genes / (K * T)

def ext_topic_genes(beta, vocab, num_top_gene):
    """
    Build, for every topic, a space-separated string of its top genes.

    Each row of ``beta`` is one topic's weights over the genes in ``vocab``.
    The ``num_top_gene`` genes with the largest weights are listed from the
    highest weight downwards.

    Returns a list with one string per row of ``beta``.
    """
    gene_names = np.array(vocab)

    def top_genes_string(weights):
        # Ascending argsort, then walk backwards over the last num_top_gene
        # positions to obtain the heaviest genes in descending order.
        ascending = np.argsort(weights)
        heaviest_first = ascending[:-(num_top_gene + 1):-1]
        return ' '.join(gene_names[heaviest_first])

    return [top_genes_string(topic_dist) for topic_dist in beta]

# Top-10 genes of every topic (topics as rows after the transpose), one string
# per topic, followed by the distinctiveness of those gene lists.
topic_gene_weights = data_tg.T.values
topic_str_list = ext_topic_genes(topic_gene_weights, data_tg.index, 10)
TD = TD_eva(topic_str_list)
print(f"===>TD_T{10}: {TD:.5f}")

## TQ

In [None]:
# TQ combines the two scores computed above as their product.
TQ = TD * TC
print(f"===>TQ_T{10}: {TQ:.5f}")

## IP

In [None]:
from sklearn import metrics
def purity_score(y_true, y_pred):
    """
    Return the purity of the predicted grouping ``y_pred`` against ``y_true``.

    The contingency table has one row per true label and one column per
    predicted label. Purity is the sum of each column's largest count
    divided by the total count of the table.
    """
    table = metrics.cluster.contingency_matrix(y_true, y_pred)
    # Largest true-label count inside every predicted group
    majority_counts = table.max(axis=0)
    return majority_counts.sum() / table.sum()
# IP: every cell is assigned the index of its largest embedding dimension, and
# the purity of that assignment is measured against the annotated cell types.
dominant_dimension = np.argmax(adata.X, axis=1)
print(purity_score(adata.obs['cell_type'], dominant_dimension))

## ORA<sub>N</sub>, ORA<sub>U</sub>, and ORA<sub>Q</sub>

In [None]:
# `copy` is imported here because the notebook-level import only happens in a
# later cell; without it this cell fails with NameError when run in order.
import copy

# Load the gene-set collection and upper-case every gene symbol
gene_sets = gp.read_gmt(path="../data/c2.all.v2024.1.Hs.symbols.gmt")
gene_sets = {key: [val.upper() for val in values] for key, values in gene_sets.items()}

print('------------')
print(len(gene_sets))
print('------------')

# Top-10 genes of every topic, one space-separated string per topic
topic_str_list = ext_topic_genes(data_tg.T.values, data_tg.index, 10)
gene_set = copy.deepcopy(gene_sets)

# Over-representation analysis of each topic's top genes. Results are gathered
# in a list and concatenated once: DataFrame.append no longer exists in
# pandas >= 2.0, and a failure on the first topic must not discard the rest.
topic_results = []
for index, data_temp in enumerate(topic_str_list):
    try:
        pre_res = gp.enrich(gene_list=data_temp.split(), # or gene_list=glist
                    gene_sets=gene_set,
                    outdir=None,
                    verbose=True)
        pre_res.res2d.insert(0,'topic_index',index)
    except Exception as err:
        # Best effort: a topic whose enrichment fails is skipped, but the
        # reason is reported instead of being silently swallowed.
        print(f'NAN (topic {index}): {err!r}')
    else:
        topic_results.append(pre_res.res2d)

if not topic_results:
    raise RuntimeError('ORA produced no result for any topic')
all_res = pd.concat(topic_results)

# Keep significant enrichments only, then report: number of significant
# (topic, term) hits, number of distinct terms among them, and their ratio
# (presumably ORA_N / ORA_U / ORA_Q in some order -- confirm against the paper).
all_res_table = all_res[all_res["Adjusted P-value"] <=0.01]
n_hits = len(all_res_table)
n_unique_terms = len(set(all_res_table['Term']))
print(n_hits)
print(n_unique_terms)
# Guard the ratio: with no significant hit it would divide by zero
print(n_unique_terms / n_hits if n_hits else 0.0)

## GSEA<sub>N</sub>, GSEA<sub>U</sub>, and GSEA<sub>Q</sub>

In [None]:
import copy

# Load the gene-set collection and upper-case every gene symbol
gene_sets = gp.read_gmt(path="../data/c2.all.v2024.1.Hs.symbols.gmt")
gene_sets = {key: [val.upper() for val in values] for key, values in gene_sets.items()}

print('------------')
print(len(gene_sets))
print('------------')
gene_set = copy.deepcopy(gene_sets)

# Pre-ranked GSEA per topic, using the topic's gene weights (one column of
# data_tg) as the ranking. Results are gathered in a list and concatenated
# once: DataFrame.append no longer exists in pandas >= 2.0, and a failure on
# the first topic must not discard the results of all later topics.
topic_results = []
for index in range(data_tg.shape[1]):
    data_temp = data_tg.iloc[:, index]
    try:
        pre_res = gp.prerank(rnk=data_temp, # or rnk = rnk,
                        gene_sets=gene_set,#
                        threads=100,
                        outdir=None, # don't write to disk
                        )
        pre_res.res2d.insert(0,'topic_index',index)
    except Exception as err:
        # Best effort: a topic whose analysis fails is skipped, but the
        # reason is reported instead of being silently swallowed.
        print(f'NAN (topic {index}): {err!r}')
    else:
        topic_results.append(pre_res.res2d)

if not topic_results:
    raise RuntimeError('GSEA produced no result for any topic')
all_res = pd.concat(topic_results)

# Keep significant enrichments only, then report: number of significant
# (topic, term) hits, number of distinct terms among them, and their ratio
# (presumably GSEA_N / GSEA_U / GSEA_Q in some order -- confirm against the paper).
all_res_table = all_res[all_res["FDR q-val"] <=0.01]
n_hits = len(all_res_table)
n_unique_terms = len(set(all_res_table['Term']))
print(n_hits)
print(n_unique_terms)
# Guard the ratio: with no significant hit it would divide by zero
print(n_unique_terms / n_hits if n_hits else 0.0)