In [1]:
import scanpy as sc
import anndata as ad
import scib
import numpy as np
import pandas as pd

In [2]:
%run ./custom_silhouette_functions.ipynb

[0;31mSignature:[0m
[0msilhouette_samples_custom[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mX[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlabels[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmetric[0m[0;34m=[0m[0;34m'euclidean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbetween_cluster_distances[0m[0;34m=[0m[0;34m'nearest'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Compute the average silhouette score for the dataset X with the given labels.

Parameters:
X : array-like, shape (n_samples, n_features)
    Feature array.
labels : array-like, shape (n_samples,)
    Labels of each point.
    
metric : metric for distance calculation, default:"euclidean", alternatives, e.g., "cosine"

between_cluster_distances: one out of "mean_other", "furthest", "nearest"


Returns:
score : float
    The average silhouette score.
[0;31mFile:[0m      /tmp/7428112.1.all.q/ipykernel_1258377/4094074416.py
[0;31mType:[0m      functio

In [3]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


In [4]:
scenarios = ['Overcorrected', 'None', 'Mild', 'Intermediate', 'Strong']

In [5]:
np.random.seed(61)

# Collect computed scores, nested dict is simple to convert to pd.DataFrame
score_dict = {}
for scenario in scenarios:
    # Initialize nested dict
    score_dict[scenario] = {}
    
    adata = ad.AnnData(X=pd.read_csv('data/simulated/counts_matrix_{}.csv'.format(scenario), index_col=None).values.T, obs=pd.read_csv('data/simulated/cell_metadata_{}.csv'.format(scenario), index_col=0))
    adata.obs['Batch'] = adata.obs['Batch'].astype('category')
    adata.obs['Cell_type'] = adata.obs['Cell_type'].astype('category')

    # Normalizing to median total counts
    sc.pp.normalize_total(adata)
    # Logarithmize the data
    sc.pp.log1p(adata)
    
    # Create embedding (PCA space) ['X_pca']
    sc.tl.pca(adata)
    sc.pp.neighbors(adata, use_rep='X_pca')
    sc.tl.umap(adata)

    
    # Compute scores
    ## Level of evaluation: batch/sample
    
    ### asw_batch
    score = scib.me.silhouette_batch(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        verbose=False
    )
    score_dict[scenario]['asw_batch'] = score
    
    score = scib.me.silhouette_batch(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        metric='cosine',
        verbose=False
    )
    score_dict[scenario]['asw_batch_cosine'] = score
    
    
    ### asw_batch_mean_other
    score = silhouette_batch_custom(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        between_cluster_distances='mean_other',
        verbose=False
    )
    score_dict[scenario]['asw_batch_mean_other'] = score
    
    score = silhouette_batch_custom(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        between_cluster_distances='mean_other',
        metric='cosine',
        verbose=False
    )
    score_dict[scenario]['asw_batch_mean_other_cosine'] = score
    
    ### asw_batch_furthest
    score = silhouette_batch_custom(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        between_cluster_distances='furthest',
        verbose=False
    )
    score_dict[scenario]['asw_batch_furthest'] = score
    
    score = silhouette_batch_custom(
        adata,
        batch_key='Batch',
        group_key='Cell_type',
        embed='X_pca',
        between_cluster_distances='furthest',
        metric='cosine',
        verbose=False
    )
    score_dict[scenario]['asw_batch_furthest_cosine'] = score
    
    ### graph iLISI and cLISI on variable batch
    score_dict[scenario]['iLISI_batch'], score_dict[scenario]['cLISI_full'] =  scib.me.lisi.lisi_graph(adata, batch_key='Batch', label_key='Cell_type', type_='knn')
        
    means = []
    total = 0
    for cell_type in adata.obs['Cell_type'].unique():
        tmp_adata = adata[adata.obs['Cell_type']==cell_type]
        cell_type_iLISI = scib.metrics.ilisi_graph(tmp_adata, batch_key='Batch', type_='knn')
        means += [cell_type_iLISI * tmp_adata.shape[0]]
        total += tmp_adata.shape[0]
        print(cell_type, cell_type_iLISI)
    print(means)
    print(np.nansum(means)/total)
    score_dict[scenario]['CiLISI_batch'] = np.nansum(means)/total
    
    ### asw_label
    score = scib.me.silhouette(
        adata,
        group_key='Cell_type',
        embed='X_pca',
    )
    score_dict[scenario]['asw_label'] = score
    
    score = scib.me.silhouette(
        adata,
        group_key='Cell_type',
        embed='X_pca',
        metric='cosine'
    )
    score_dict[scenario]['asw_label_cosine'] = score
    
    
    
    ### nmi    
    scib.metrics.cluster_optimal_resolution(
        adata,
        label_key='Cell_type',
        cluster_key='cluster',
        metric=scib.me.nmi
    )
    
    score = scib.me.nmi(
        adata,
        group1='cluster',
        group2='Cell_type'
    )
    
    score_dict[scenario]['nmi'] = score
    
    ### ari
    scib.metrics.cluster_optimal_resolution(
        adata,
        label_key='Cell_type',
        cluster_key='cluster',
        metric=scib.me.ari
    )
    
    score = scib.me.ari(adata, cluster_key="cluster", label_key="Cell_type")
    score_dict[scenario]['ari'] = score

  from .autonotebook import tqdm as notebook_tqdm


Cell_type3 0.7821308420568993
Cell_type2 0.768738282569522
Chunk 206 does not have enough neighbors. Skipping...
Chunk 614 does not have enough neighbors. Skipping...
Cell_type1 0.773300063272377
[1623.703628110123, 868.6742593035599, 614.0002502382673]
0.7765945344129875
resolution: 0.1, nmi: 0.0
resolution: 0.2, nmi: 0.0
resolution: 0.3, nmi: 0.0
resolution: 0.4, nmi: 0.0
resolution: 0.5, nmi: 0.0
resolution: 0.6, nmi: 0.0
resolution: 0.7, nmi: 0.0010376682558042737
resolution: 0.8, nmi: 0.0013966867245827038
resolution: 0.9, nmi: 0.0013655904797081576
resolution: 1.0, nmi: 0.0033183496204888096
resolution: 1.1, nmi: 0.0033153487909835465
resolution: 1.2, nmi: 0.00308561121137692
resolution: 1.3, nmi: 0.00294745876059362
resolution: 1.4, nmi: 0.005396429216153936
resolution: 1.5, nmi: 0.003953948834615406
resolution: 1.6, nmi: 0.004777102435483651
resolution: 1.7, nmi: 0.005747913634983694
resolution: 1.8, nmi: 0.005623215188243745
resolution: 1.9, nmi: 0.005613228122150825
resolutio

In [6]:
scores = pd.DataFrame(score_dict)
scores

Unnamed: 0,Overcorrected,None,Mild,Intermediate,Strong
asw_batch,0.99325,0.993038,0.993213,0.993096,0.994865
asw_batch_cosine,0.990154,0.989877,0.989834,0.990327,0.990636
asw_batch_mean_other,0.994599,0.994812,0.982046,0.962877,0.937804
asw_batch_mean_other_cosine,0.992678,0.992166,0.964275,0.925486,0.876157
asw_batch_furthest,0.992805,0.993199,0.970899,0.942532,0.907778
asw_batch_furthest_cosine,0.990485,0.98948,0.943966,0.888737,0.82228
iLISI_batch,0.779334,0.780233,0.700693,0.489355,0.334565
cLISI_full,0.291609,0.721169,0.738715,0.834451,0.787153
CiLISI_batch,0.776595,0.783159,0.706151,0.496775,0.338386
asw_label,0.498388,0.523401,0.523963,0.524827,0.523369


In [7]:
pd.DataFrame(score_dict).to_csv("evaluation/batch_removal_scores.csv", index=True)