In [None]:
!pip install scanpy loompy #suggest to install these two packages

Collecting scprep
  Downloading scprep-1.1.0-py3-none-any.whl (104 kB)
[?25l[K     |███▏                            | 10 kB 18.6 MB/s eta 0:00:01[K     |██████▎                         | 20 kB 14.2 MB/s eta 0:00:01[K     |█████████▍                      | 30 kB 15.0 MB/s eta 0:00:01[K     |████████████▌                   | 40 kB 13.6 MB/s eta 0:00:01[K     |███████████████▊                | 51 kB 13.5 MB/s eta 0:00:01[K     |██████████████████▉             | 61 kB 14.8 MB/s eta 0:00:01[K     |██████████████████████          | 71 kB 14.6 MB/s eta 0:00:01[K     |█████████████████████████       | 81 kB 14.9 MB/s eta 0:00:01[K     |████████████████████████████▎   | 92 kB 14.9 MB/s eta 0:00:01[K     |███████████████████████████████▍| 102 kB 15.4 MB/s eta 0:00:01[K     |████████████████████████████████| 104 kB 15.4 MB/s 
[?25hCollecting phate
  Downloading phate-1.0.7-py3-none-any.whl (23 kB)
Collecting magic-impute
  Downloading magic_impute-3.0.0-py3-none-any.whl (1

In [None]:
pip install git+https://github.com/theislab/scib.git

Collecting git+https://github.com/theislab/scib.git
  Cloning https://github.com/theislab/scib.git to /tmp/pip-req-build-v0isbkpp
  Running command git clone -q https://github.com/theislab/scib.git /tmp/pip-req-build-v0isbkpp
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting anndata2ri
  Downloading anndata2ri-1.0.6-py3-none-any.whl (24 kB)
Collecting python-igraph
  Downloading python-igraph-0.9.8.tar.gz (9.5 kB)
Collecting louvain
  Downloading louvain-0.7.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 32.9 MB/s 
[?25hCollecting numpy==1.18.1
  Downloading numpy-1.18.1-cp37-cp37m-manylinux1_x86_64.whl (20.1 MB)
[K     |████████████████████████████████| 20.1 MB 1.4 MB/s 
Collecting h5py<3
  Downloading h5py-2.10.0-cp37-cp37m-manylinux1_x86_64.whl (2.9 MB)
[K     |███████████████████████████

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import scib

In [None]:
# adata: the predicted embeddings produced by our method

# adata_sol: the provided solution (ground-truth batch / cell_type annotations)
def load_data(pred, solu):
    '''Load the predicted embeddings and the solution, aligned cell by cell.

    adata: the predicted embeddings
    adata_sol: the provided solution

    Parameters
    ----------
    pred : str
        Path to the .h5ad file holding the predicted embeddings.
    solu : str
        Path to the .h5ad file holding the provided solution.

    Returns
    -------
    (adata, adata_sol)
        ``adata`` gains ``obs['batch']`` and ``obs['cell_type']`` copied from
        the solution. ``adata_sol`` is restricted to the cells of ``adata``,
        reordered to follow ``adata.obs_names``, and is a real copy rather
        than a view of the full solution.

    Every cell of ``adata`` must be present in the solution; otherwise the
    subsetting below fails.
    '''
    # scanpy re-exports anndata's reader. The previous ``ad.read_h5ad`` relied on
    # an ``import anndata as ad`` that no visible cell of this notebook provides.
    adata = sc.read_h5ad(pred)
    adata_sol = sc.read_h5ad(solu)

    # Subset the solution to the predicted cells *in adata's row order*, so the two
    # objects handed together to scib (cell-cycle / trajectory metrics) line up.
    # NOTE(review): assumed that scib pairs some per-cell values positionally;
    # aligning the rows here makes the metrics correct either way.
    # ``.copy()`` materialises the subset instead of keeping a view.
    adata_sol = adata_sol[adata.obs_names].copy()

    # Both objects now share identical obs_names, so these align 1:1 by index.
    adata.obs['batch'] = adata_sol.obs['batch']
    adata.obs['cell_type'] = adata_sol.obs['cell_type']

    print(adata.shape, adata_sol.shape)
    return adata, adata_sol

# NMI between optimised Louvain clusters and the annotated cell types
def get_nmi(adata):
    """Normalised mutual information between clusters and ``obs['cell_type']``.

    Builds a kNN graph on ``obsm['X_emb']``, lets scib choose the Louvain
    resolution against ``obs['cell_type']`` (clusters written to
    ``obs['cluster']``), then scores that clustering with scib's NMI.

    Side effects: overwrites the neighbour graph and ``obs['cluster']``.
    """
    print('Preprocessing')
    sc.pp.neighbors(adata, use_rep='X_emb')

    print('Clustering')
    louvain_opts = dict(
        label_key='cell_type',
        cluster_key='cluster',
        plot=False,
        inplace=True,
        force=True,
    )
    scib.cl.opt_louvain(adata, **louvain_opts)

    print('Compute score')
    return scib.me.nmi(adata, group1='cluster', group2='cell_type')

# cell-type average silhouette width (ASW) of the embedding
def get_cell_type_ASW(adata):
    """scib silhouette score of ``obs['cell_type']`` groups in ``obsm['X_emb']``."""
    asw = scib.me.silhouette(adata, group_key='cell_type', embed='X_emb')
    return asw

# cell-cycle conservation
def get_cell_cycle_conservation(adata, adata_solution):
    """scib cell-cycle conservation: solution (``adata_pre``) vs. embedding (``adata_post``).

    scib is asked to recompute cell-cycle scores unless the solution already
    carries both ``obs['S_score']`` and ``obs['G2M_score']``. The organism is
    read from ``adata_solution.uns['organism']``, which must exist (KeyError
    otherwise). The embedding is taken from ``obsm['X_emb']``, batches from
    ``obs['batch']``.
    """
    solution_keys = adata_solution.obs_keys()
    has_precomputed = all(key in solution_keys for key in ('S_score', 'G2M_score'))
    species = adata_solution.uns['organism']

    print('Compute score')
    return scib.me.cell_cycle(
        adata_pre=adata_solution,
        adata_post=adata,
        batch_key='batch',
        embed='X_emb',
        recompute_cc=not has_precomputed,
        organism=species,
    )

# trajectory conservation
def get_traj_conservation(adata, adata_solution):
    """Mean scib trajectory-conservation score over GEX and ADT/ATAC pseudotimes.

    The second-modality pseudotime is ``obs['pseudotime_order_ATAC']`` when the
    solution has it, otherwise ``obs['pseudotime_order_ADT']``. A modality whose
    pseudotime column is absent scores ``np.nan``. Because the two scores are
    averaged arithmetically, one missing modality makes the result NaN.

    Side effect: recomputes the neighbour graph of ``adata`` on ``obsm['X_emb']``.
    """
    if 'pseudotime_order_ATAC' in adata_solution.obs:
        second_key = 'pseudotime_order_ATAC'
    else:
        second_key = 'pseudotime_order_ADT'

    sc.pp.neighbors(adata, use_rep='X_emb')
    available = adata_solution.obs_keys()

    def _score(pseudotime_key):
        # np.nan marks a modality for which the solution has no pseudotime
        if pseudotime_key not in available:
            return np.nan
        return scib.me.trajectory_conservation(
            adata_pre=adata_solution,
            adata_post=adata,
            label_key='cell_type',
            pseudotime_key=pseudotime_key,
        )

    rna = _score('pseudotime_order_GEX')
    adt_atac = _score(second_key)
    return (rna + adt_atac) / 2

# batch average silhouette width (ASW), grouped by cell type
def get_batch_ASW(adata):
    """scib batch silhouette of ``obsm['X_emb']`` w.r.t. ``obs['batch']``.

    Cells are grouped by ``obs['cell_type']`` (scib ``group_key``); scib's
    verbose output is suppressed.
    """
    return scib.me.silhouette_batch(
        adata,
        batch_key='batch', group_key='cell_type',
        embed='X_emb', verbose=False,
    )

# graph connectivity of the kNN graph, per cell-type label
def get_graph_connectivity(adata):
    """scib graph-connectivity score of the kNN graph built on ``obsm['X_emb']``.

    Uses ``obs['cell_type']`` as labels. Side effect: recomputes the
    neighbour graph of ``adata``.
    """
    sc.pp.neighbors(adata, use_rep='X_emb')
    print('Compute score')
    connectivity = scib.me.graph_connectivity(adata, label_key='cell_type')
    return connectivity

In [None]:
if __name__ == "__main__":
    # Input files -- edit these two paths to score a different submission.
    PRED_PATH = '/content/drive/MyDrive/experimentdata/openproblems_bmmc_cite_phase1v2.method.output.h5ad'
    SOLUTION_PATH = '/content/drive/MyDrive/experimentdata/openproblems_bmmc_cite_phase1v2.censor_dataset.output_solution.h5ad'

    adata, adata_sol = load_data(PRED_PATH, SOLUTION_PATH)

    # Every metric helper reads the embedding from obsm['X_emb'], so expose the
    # predicted embedding (held in .X of the method output) there.
    adata.obsm['X_emb'] = adata.X

    # Cell-type (biological conservation) metrics
    nmi = get_nmi(adata)
    cell_type_asw = get_cell_type_ASW(adata)
    cc_con = get_cell_cycle_conservation(adata, adata_sol)
    traj_con = get_traj_conservation(adata, adata_sol)

    # Batch metrics
    batch_asw = get_batch_ASW(adata)
    graph_score = get_graph_connectivity(adata)

    print('cell type rate')
    print('nmi:', nmi, '    celltype asw:', cell_type_asw, '       cell cycle:', cc_con, '          traj:', traj_con)
    print('batch rate')
    print('batch asw:', batch_asw, '   graph connectivity score:', graph_score)

    # Unweighted mean of all six metrics; any NaN metric (e.g. trajectory score
    # when the solution lacks a pseudotime column) makes the average NaN.
    all_scores = [nmi, cell_type_asw, cc_con, traj_con, batch_asw, graph_score]
    print('average metric: %.5f' % np.mean(all_scores))

(66175, 50) (66175, 13953)
Preprocessing
Clustering
Clustering...
resolution: 0.1, nmi: 0.5358980224510017
resolution: 0.2, nmi: 0.5744486801143102
resolution: 0.3, nmi: 0.6135683266708654
resolution: 0.4, nmi: 0.622131294297788
resolution: 0.5, nmi: 0.6213551010781458
resolution: 0.6, nmi: 0.6237748193101142
resolution: 0.7, nmi: 0.6345320009023363
resolution: 0.8, nmi: 0.6369256275724962
resolution: 0.9, nmi: 0.6336794377685178
resolution: 1.0, nmi: 0.6450626495798543
resolution: 1.1, nmi: 0.6409919898382503
resolution: 1.2, nmi: 0.6437226807347437
resolution: 1.3, nmi: 0.6452994205855423
resolution: 1.4, nmi: 0.6463839550450831
resolution: 1.5, nmi: 0.6462957311744434
resolution: 1.6, nmi: 0.639478405425506
resolution: 1.7, nmi: 0.6449422496855712
resolution: 1.8, nmi: 0.645966169599072
resolution: 1.9, nmi: 0.6481732991855541
resolution: 2.0, nmi: 0.6392839501336913
optimised clustering against cell_type
optimal cluster resolution: 1.9
optimal score: 0.6481732991855541
Compute scor

Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.


Compute score
cell type rate
nmi: 0.6481732991855541     celltype asw: 0.5301379766315222        cell cycle: 0.9291062219729463           traj: 0.7921608206728267
batch rate
batch asw: 0.7755378427109703    graph connectivity score: 0.8608877205680023
average metric: 0.75600
