In [None]:
import json
import scanpy as sc
import muon as mu
import anndata as ad
import numpy as np
import warnings
import logging
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import balanced_accuracy_score
import os
import scib_metrics
from scib_metrics.nearest_neighbors import NeighborsResults
import matplotlib.pyplot as plt

def most_frequent(arr):
    """Return the element that occurs most often in ``arr``.

    ``np.unique`` yields the distinct values in sorted order and ``argmax``
    picks the first maximum, so ties resolve to the smallest value.
    """
    uniques, tallies = np.unique(arr, return_counts=True)
    winner = tallies.argmax()
    return uniques[winner]
    
# Keep the notebook output quiet: the root logger only reports WARNING and
# above, and every Python warning is ignored.
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)
warnings.filterwarnings(action='ignore')

def label_pred(key, k, mdata, adata, context_key='mouse', target_key='human', n_latent_nns=25):
    """Transfer context cell-type labels to the target cells and store them in ``adata``.

    Two majority votes are taken for every target cell, once for the coarse
    and once for the fine cell-type labels of the context modality:

    * aligned latent space: the neighbours stored in
      ``obsm['ind_nns_aligned_latent_space']`` are ordered by ascending value of
      ``obsm['nlog_likeli_nns_aligned_latent_space']`` (the key name suggests
      negative log-likelihoods - TODO confirm) and the first ``n_latent_nns``
      vote;
    * homologous genes: the first ``k`` neighbours stored in
      ``obsm['ind_nns_hom_genes']`` vote.

    Parameters
    ----------
    key : str
        Name under which the concatenated latent means are stored in
        ``adata.obsm``. The latent-space predictions are written to
        ``adata.obs[key + 'pred_coarse']`` / ``adata.obs[key + 'pred_fine']``.
        The homologous-gene predictions use ``key[:-2]`` as prefix, i.e. the
        last two characters of ``key`` are dropped (the callers in this file
        pass keys that end in ``'_<single digit>'``), so they are overwritten
        by every call that shares that prefix.
    k : int
        Number of homologous-gene neighbours that vote.
    mdata : MuData-like
        Must provide ``mod[context_key]`` and ``mod[target_key]``.
    adata : AnnData-like
        Modified in place. Its rows are assumed to be the context cells
        followed by the target cells - TODO confirm against the caller.
    context_key, target_key : str
        Modality names of the labelled reference and of the query.
    n_latent_nns : int
        Number of latent-space neighbours that vote.
    """
    context = mdata.mod[context_key]
    target = mdata.mod[target_key]

    # Look the neighbour tables up once instead of once per target cell.
    likelihoods = target.obsm['nlog_likeli_nns_aligned_latent_space']
    latent_nns = target.obsm['ind_nns_aligned_latent_space']
    hom_nns = target.obsm['ind_nns_hom_genes']
    n_target = target.n_obs

    def vote_latent(labels):
        # reorder each neighbour list by its likelihood score, then vote
        return np.stack([
            most_frequent(labels[latent_nns[i][np.argsort(likelihoods[i])]][:n_latent_nns])
            for i in range(n_target)
        ])

    def vote_hom(labels):
        return np.stack([most_frequent(labels[hom_nns[i]][:k]) for i in range(n_target)])

    coarse_labels = context.obs['cell_type_coarse'].to_numpy()
    coarse_labels_pred = vote_latent(coarse_labels)
    coarse_labels_pred_hom = vote_hom(coarse_labels)

    fine_labels = context.obs['cell_type_fine'].to_numpy()
    fine_labels_pred = vote_latent(fine_labels)
    fine_labels_pred_hom = vote_hom(fine_labels)

    # Context cells receive the placeholder 'c'; only target rows hold predictions.
    coarse_placeholder = np.array(['c'] * len(coarse_labels))
    fine_placeholder = np.array(['c'] * len(fine_labels))

    adata.obsm[key] = np.concatenate((context.obsm['latent_mu'], target.obsm['latent_mu']))
    adata.obs[key+'pred_coarse'] = np.concatenate((coarse_placeholder, coarse_labels_pred))
    adata.obs[key+'pred_fine'] = np.concatenate((fine_placeholder, fine_labels_pred))
    adata.obs[key[:-2]+'pred_coarse_hom'] = np.concatenate((coarse_placeholder, coarse_labels_pred_hom))
    adata.obs[key[:-2]+'pred_fine_hom'] = np.concatenate((fine_placeholder, fine_labels_pred_hom))

def calc_neighbors(layer, adata):
    """Predict cell types for the ``system == 1`` cells from the ``system == 0`` cells.

    For the embedding ``adata.obsm[layer]`` a Euclidean 25-nearest-neighbour
    search is fitted on the ``system == 0`` cells (named "mouse" here) and
    queried with the ``system == 1`` cells ("human"). Every query cell gets the
    mode of its neighbours' labels; ``pd.Series.mode()[0]`` is used, so ties go
    to the first entry of the sorted modes.

    Returns
    -------
    tuple of np.ndarray
        ``(coarse_predictions, fine_predictions)``, one entry per query cell.
    """
    system = adata.obs.system
    is_reference = np.array(system == 0)
    is_query = np.array(system == 1)

    embedding = adata.obsm[layer]
    reference_points = embedding[is_reference]
    query_points = embedding[is_query]

    reference_coarse = np.array(adata.obs['cell_type_coarse'][is_reference])
    reference_fine = np.array(adata.obs['cell_type_fine'][is_reference])

    knn = NearestNeighbors(n_neighbors=25, metric='euclidean')
    knn.fit(reference_points)
    _, neighbor_idx = knn.kneighbors(query_points)

    def majority(labels):
        return pd.Series(labels).mode()[0]

    coarse_votes = [majority(reference_coarse[row]) for row in neighbor_idx]
    fine_votes = [majority(reference_fine[row]) for row in neighbor_idx]

    return np.array(coarse_votes), np.array(fine_votes)

def calc_metrics(coarse_labels_true, fine_labels_true, coarse_labels_pred, fine_labels_pred, common_cells_coarse, common_cells_fine):
    """Balanced accuracy of transferred labels, restricted to shared cell types.

    The true labels are first renamed with a fixed alias table
    (``'Kupffer cells'`` -> ``'KCs'``). Only cells whose renamed true label is
    contained in ``common_cells_coarse`` / ``common_cells_fine`` enter the
    respective score; the predictions are filtered with the same mask and are
    not renamed.

    Returns
    -------
    dict
        ``{"BAS (Coarse)": float, "BAS (Fine)": float}``.
    """
    aliases = {'Kupffer cells': 'KCs'}

    def harmonise(labels):
        return np.array([aliases.get(label, label) for label in labels])

    coarse_true = harmonise(coarse_labels_true)
    fine_true = harmonise(fine_labels_true)

    # cells whose (renamed) true label is in the shared set
    coarse_shared = np.isin(coarse_true, common_cells_coarse)
    fine_shared = np.isin(fine_true, common_cells_fine)

    bas_coarse = balanced_accuracy_score(coarse_true[coarse_shared], coarse_labels_pred[coarse_shared])
    bas_fine = balanced_accuracy_score(fine_true[fine_shared], fine_labels_pred[fine_shared])

    return {"BAS (Coarse)": bas_coarse, "BAS (Fine)": bas_fine}

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy arrays and scalars to built-in types."""

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Arrays become (nested) lists; NumPy integer, floating and boolean
        scalars become ``int``, ``float`` and ``bool``. Anything else is
        delegated to the base class, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.integer / np.floating already cover every sized subtype
        # (int32, int64, float32, ...).
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of bool and would otherwise be rejected.
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

def find_file(filename, data_path):
    """Return the path of ``filename`` inside ``data_path``, or ``None``.

    When the path does not exist, a ``SKIP`` line naming the file is printed
    and ``None`` is returned so the caller can skip that input.
    """
    candidate = os.path.join(data_path, filename)
    if not os.path.exists(candidate):
        print('SKIP: ', filename)
        return None
    return candidate

# Working directory with forward slashes and a trailing slash; inputs are read
# from <cwd>/dataset/ and results are written to <cwd>/results/.
path = os.path.abspath('').replace('\\', '/') + '/'
data_path = path + 'dataset/'
save_path = path + 'results/'

In [None]:
eval_size = 50000  # datasets with more cells are randomly subsampled to this size before scoring

# Metric names in the order in which they are written to every entry of `erg_dict`.
_MIXING_METRICS = ['PCR', 'bASW', 'kBET', 'GC']
_BIOLOGY_METRICS = ['ARI', 'NMI', 'iso', 'cASW']


def _run_keys(obsm_keys, prefix):
    """Return the keys that consist of ``prefix`` followed only by a run number.

    A plain substring test would also file ``'scSpecies_25_lat_3'`` under
    ``'scSpecies_25_'``.
    """
    return [key for key in obsm_keys if key.startswith(prefix) and key[len(prefix):].isdigit()]


def _load_saturn(file_path):
    """Read a SATURN result; return its embedding and a 10-component PCA of it."""
    adata_saturn = ad.read_h5ad(file_path)
    # Reorder to mouse cells first, then human cells (the order in which `adata` is concatenated below).
    adata_saturn = ad.concat([adata_saturn[adata_saturn.obs.species == 'mouse'], adata_saturn[adata_saturn.obs.species == 'human']], axis=0)
    sc.tl.pca(adata_saturn, n_comps=10)
    return np.array(adata_saturn.X), np.array(adata_saturn.obsm['X_pca'])


def _scib_scores(embedding, neighbors_15, neighbors_50, cell_labels, species_labels, **isolated_kwargs):
    """Compute the label-dependent scib metrics of one embedding for one label granularity."""
    level_scores = {}
    level_scores['bASW'] = scib_metrics.silhouette_batch(X=embedding, labels=np.array(cell_labels), batch=np.array(species_labels))
    level_scores['GC'] = scib_metrics.graph_connectivity(X=neighbors_15, labels=cell_labels)
    level_scores['kBET'] = scib_metrics.kbet_per_label(X=neighbors_50, batches=species_labels, labels=cell_labels)
    clustering = scib_metrics.nmi_ari_cluster_labels_leiden(X=neighbors_15, labels=cell_labels)
    level_scores['ARI'] = clustering['ari']
    level_scores['NMI'] = clustering['nmi']
    level_scores['iso'] = scib_metrics.isolated_labels(X=embedding, labels=cell_labels, batch=np.array(species_labels), **isolated_kwargs)
    level_scores['cASW'] = scib_metrics.silhouette_label(X=embedding, labels=cell_labels)
    return level_scores


def _summarise(level_scores, level):
    """Turn per-run score lists into arrays and add the aggregate scores for ``level``."""
    values = {metric: np.array(runs).squeeze() for metric, runs in level_scores.items()}
    suffix = ' (' + level + ')'

    summary = {}
    for metric in _MIXING_METRICS:
        summary[metric + suffix] = values[metric]
    summary['Species mixing' + suffix] = 0.25 * (values['PCR'] + values['bASW'] + values['kBET'] + values['GC'])
    for metric in _BIOLOGY_METRICS:
        summary[metric + suffix] = values[metric]
    summary['Biology conservation' + suffix] = 0.25 * (values['ARI'] + values['NMI'] + values['iso'] + values['cASW'])
    summary['Total' + suffix] = 0.4 * summary['Species mixing' + suffix] + 0.6 * summary['Biology conservation' + suffix]
    summary['BAS Label transfer' + suffix] = values['BAS']
    return summary


for dataset in ['liver_human', "glio", "adipose"]:
    context_key = 'mouse'
    target_key = 'human'
    # File-name stem of the raw data; the liver comparison reads 'liver.h5mu'.
    base_key = 'liver' if dataset == 'liver_human' else dataset

    data_file = find_file(base_key+".h5mu", data_path)
    if data_file is None:
        # find_file already printed the SKIP line; nothing can be evaluated without the raw data
        continue
    mdata = mu.read_h5mu(data_file)

    mdata[context_key].obs['system'] = 0
    mdata[target_key].obs['system'] = 1

    context_genes = np.array(mdata[context_key].var['human_gene_names'])
    target_genes = np.array(mdata[target_key].var['human_gene_names'])

    # Restrict both species to the shared (human-named) genes, in identical order.
    ret_vec, ind_a, ind_b = np.intersect1d(target_genes, context_genes, return_indices=True)
    mdata_target = mdata[target_key][:, ind_a]
    mdata_context = mdata[context_key][:, ind_b]

    mdata_target.var_names = target_genes[ind_a]
    mdata_context.var_names = context_genes[ind_b]

    # Rows: all context (mouse) cells followed by all target (human) cells.
    adata = ad.concat([mdata_context, mdata_target], axis=0, join='inner')

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    sc.tl.pca(adata, n_comps=10)
    adata.obsm["Unintegrated"] = adata.obsm["X_pca"]

    if 'batch_label_enc' in adata.obsm:
        del adata.obsm['batch_label_enc']

    # NOTE(review): the original loaded the same file for the NOANNOT and the
    # ANNOT variant; the name is kept as is because the intended second file
    # name cannot be inferred from here.
    scspecies_runs = [
        ('_1_inter_', 'scSpecies_1_', 1),
        ('_25_inter_', 'scSpecies_25_', 25),
        ('_25_latent_', 'scSpecies_25_lat_', 25),
        ('_250_inter_', 'scSpecies_250_', 250),
    ]

    for i in range(10):
        print(i+1, 'of', 10)

        load_key = find_file(dataset+'_embed_sysVI_'+str(i)+'.h5ad', data_path)
        if load_key is not None:
            adata.obsm['sysVI_'+str(i)] = ad.read_h5ad(load_key).X

        load_key = find_file(dataset+'_unaligned_'+str(i)+'_mdata.h5mu', data_path)
        if load_key is not None:
            adata.obsm['scVI_'+str(i)] = mu.read_h5mu(load_key)['train'].obsm['latent_mu']

        load_key = find_file("test256_data_"+target_key+"_"+dataset+"_"+context_key+"_"+dataset+"_org_saturn_seed_"+str(i*1234)+".h5ad", data_path)
        if load_key is not None:
            saturn_embedding, saturn_pca = _load_saturn(load_key)
            adata.obsm['SATURN_NOANNOT_'+str(i)] = saturn_embedding
            adata.obsm['SATURN_PCA_NOANNOT_'+str(i)] = saturn_pca

        load_key = find_file("test256_data_"+target_key+"_"+dataset+"_"+context_key+"_"+dataset+"_org_saturn_seed_"+str(i*1234)+".h5ad", data_path)
        if load_key is not None:
            saturn_embedding, saturn_pca = _load_saturn(load_key)
            adata.obsm['SATURN_ANNOT_'+str(i)] = saturn_embedding
            adata.obsm['SATURN_PCA_ANNOT_'+str(i)] = saturn_pca

        load_key = find_file(dataset+'_embed_scArches_'+str(i)+'.h5ad', data_path)
        if load_key is not None:
            adata.obsm['scArches_'+str(i)] = ad.read_h5ad(load_key).X

        load_key = find_file(dataset+'_embed_scPoli_'+str(i)+'.h5ad', data_path)
        if load_key is not None:
            adata.obsm['scPoli_'+str(i)] = ad.read_h5ad(load_key).X

        # The scSpecies file names are built from the data stem, not from the
        # result of the previous find_file call (which is a path or None).
        # NOTE(review): stem assumed to be the raw-data stem (e.g. 'liver') - confirm.
        for file_tag, obsm_prefix, k in scspecies_runs:
            load_key = find_file(base_key+'_'+context_key+'_'+target_key+file_tag+str(i)+'.h5mu', data_path)
            if load_key is not None:
                run_mdata = mu.read_h5mu(load_key)
                label_pred(obsm_prefix+str(i), k, run_mdata, adata)

    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    transl_dict = {'Kupffer cells': 'KCs'}

    human_label_coarse = np.array(adata.obs['cell_type_coarse'][human_ind])
    human_label_coarse = np.array([transl_dict.get(ct, ct) for ct in human_label_coarse])
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_coarse = np.array(adata.obs['cell_type_coarse'][mouse_ind])
    mouse_label_coarse = np.array([transl_dict.get(ct, ct) for ct in mouse_label_coarse])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
    common_cells_coarse = np.intersect1d(mouse_label_coarse, human_label_coarse)
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)

    # Baselines that only provide label transfer: evaluated on all cells, every
    # integration metric is NaN.
    erg_dict = {}
    hom_prefix = {'kNN_1': 'scSpecies_1', 'kNN_25': 'scSpecies_25', 'kNN_250': 'scSpecies_250'}
    for method in ['Celltypeist', 'kNN_1', 'kNN_25', 'kNN_250']:
        if method == 'Celltypeist':
            cell_typeist_fine = pd.read_csv(data_path+dataset+'_cell_type_fine_predictions.csv')
            # The original read the fine table twice; the coarse table is assumed
            # to follow the same naming scheme - TODO confirm the file name.
            cell_typeist_coarse = pd.read_csv(data_path+dataset+'_cell_type_coarse_predictions.csv')
            # The original indexed with `perm`, which is only created further
            # down; no subsampling has happened yet, so the human rows are taken
            # directly. NOTE(review): assumes one row per cell of `adata`, in
            # the same order - confirm.
            fine_labels_pred = np.array(cell_typeist_fine['majority_voting'])[human_ind]
            coarse_labels_pred = np.array(cell_typeist_coarse['majority_voting'])[human_ind]
        else:
            fine_labels_pred = np.array(adata.obs[hom_prefix[method]+'pred_fine_hom'])[human_ind]
            coarse_labels_pred = np.array(adata.obs[hom_prefix[method]+'pred_coarse_hom'])[human_ind]

        metrics = calc_metrics(human_label_coarse, human_label_fine, coarse_labels_pred, fine_labels_pred, common_cells_coarse, common_cells_fine)

        entry = {}
        for level in ['coarse', 'fine']:
            for metric in _MIXING_METRICS + ['Species mixing'] + _BIOLOGY_METRICS + ['Biology conservation', 'Total']:
                entry[metric+' ('+level+')'] = [np.nan]
        entry['BAS Label transfer (coarse)'] = [metrics['BAS (Coarse)']]
        entry['BAS Label transfer (fine)'] = [metrics['BAS (Fine)']]
        erg_dict[method] = entry

    if adata.n_obs > eval_size:
        # NOTE(review): unseeded, so the evaluated subset differs between runs.
        perm = np.random.permutation(adata.n_obs)[:eval_size]
        adata = adata[perm].copy()

    obsm_keys = list(adata.obsm.keys())
    model_lists = [['Unintegrated']] + [
        _run_keys(obsm_keys, prefix)
        for prefix in [
            'SATURN_PCA_NOANNOT_', 'SATURN_PCA_ANNOT_', 'sysVI_', 'scVI_', 'scArches_', 'scPoli_',
            'scSpecies_1_', 'scSpecies_25_', 'scSpecies_25_lat_', 'scSpecies_250_',
        ]
    ]
    species_labels = adata.obs['system']
    cell_labels_coarse = adata.obs['cell_type_coarse']
    cell_labels_fine = adata.obs['cell_type_fine']

    # Recompute masks and labels for the (possibly subsampled) data.
    # NOTE(review): unlike above, the coarse labels are not renamed with
    # transl_dict before the shared cell types are determined - confirm intent.
    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    human_label_coarse = np.array(adata.obs['cell_type_coarse'][human_ind])
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_coarse = np.array(adata.obs['cell_type_coarse'][mouse_ind])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
    common_cells_coarse = np.intersect1d(mouse_label_coarse, human_label_coarse)
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)

    os.makedirs(save_path, exist_ok=True)

    for model_list in model_lists:
        if not model_list:
            # none of this family's result files were found
            continue

        # Drops the last two characters (the '_<run>' suffix).
        # NOTE(review): this also turns 'Unintegrated' into 'Unintegrat'; kept
        # so that existing result files keep their keys.
        result_key = model_list[0][:-2]
        erg_dict[result_key] = {}

        scores = {
            level: {metric: [] for metric in _MIXING_METRICS + _BIOLOGY_METRICS + ['BAS']}
            for level in ['coarse', 'fine']
        }

        for model_key in model_list:
            if 'scSpecies' in model_key:
                # scSpecies ships its own label transfer (written by label_pred)
                human_label_coarse_pred = adata.obs[model_key+'pred_coarse'][human_ind]
                human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
            else:
                human_label_coarse_pred, human_label_fine_pred = calc_neighbors(model_key, adata)

            metrics = calc_metrics(human_label_coarse, human_label_fine, human_label_coarse_pred, human_label_fine_pred, common_cells_coarse, common_cells_fine)
            scores['coarse']['BAS'].append(metrics["BAS (Coarse)"])
            scores['fine']['BAS'].append(metrics["BAS (Fine)"])

            sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
            distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
            neighbors_results_15 = NeighborsResults(distances=distances[:, :15], indices=indices[:, :15])
            neighbors_results_50 = NeighborsResults(distances=distances[:, :50], indices=indices[:, :50])

            embedding = np.array(adata.obsm[model_key])

            # PCR does not depend on the cell-type labels; the same value is
            # recorded for both granularities.
            pcr = scib_metrics.pcr_comparison(X_pre=adata.obsm['Unintegrated'],
                                              X_post=embedding,
                                              covariate=np.array(species_labels))

            # Only the fine-label isolated-label score passes iso_threshold=1.
            for level, cell_labels, isolated_kwargs in [
                ('coarse', cell_labels_coarse, {}),
                ('fine', cell_labels_fine, {'iso_threshold': 1}),
            ]:
                level_scores = _scib_scores(embedding, neighbors_results_15, neighbors_results_50, cell_labels, species_labels, **isolated_kwargs)
                level_scores['PCR'] = pcr
                for metric, value in level_scores.items():
                    scores[level][metric].append(value)

        for level in ['fine', 'coarse']:
            erg_dict[result_key].update(_summarise(scores[level], level))

        # Written after every method family so partial results survive a crash.
        with open(save_path + dataset + "_erg_dict.json", "w") as f:
            json.dump(erg_dict, f, cls=NumpyEncoder)

    mean_dict = {}
    std_dict = {}

    for key, sub_dict in erg_dict.items():
        mean_dict[key] = {metric: np.mean(values) for metric, values in sub_dict.items()}
        std_dict[key] = {metric: np.std(values) for metric, values in sub_dict.items()}

    mean_df = pd.DataFrame(mean_dict).T
    std_df = pd.DataFrame(std_dict).T

    # Column-wise min-max scaling of the mean scores.
    shifted_df = mean_df - mean_df.min(axis=0)
    min_max_df = shifted_df / shifted_df.max(axis=0)

    # NOTE(review): index=False drops the method names from the files; rows are
    # in the insertion order of erg_dict.
    mean_df.to_csv(save_path+dataset+"_mean_df.csv", index=False)
    std_df.to_csv(save_path+dataset+"_std_df.csv", index=False)
    min_max_df.to_csv(save_path+dataset+"_min_max_df.csv", index=False)
    

In [None]:
# Cell-count limit for evaluation (presumably used for subsampling further down - TODO confirm).
eval_size = 50_000

def label_pred(key, k, mdata, adata):
    """Transfer mouse cell labels to the 'mouseNafld' cells and store them in adata.

    Two kinds of predictions are made for every target cell, each by majority
    vote (``most_frequent``) over context ('mouse') labels:

    * aligned latent space: the stored neighbour indices are re-ordered by
      ascending ``nlog_likeli_nns_aligned_latent_space`` and the first 25 vote;
    * homologous genes: the first ``k`` stored ``ind_nns_hom_genes`` vote.

    Parameters
    ----------
    key : str
        Name of the ``adata.obsm`` slot for the joint latent representation;
        also the prefix of the ``pred_coarse`` / ``pred_fine`` columns.  The
        homologous-gene columns use ``key[:-2]`` as prefix.
    k : int
        Number of homologous-gene neighbours that take part in the vote.
    mdata : MuData-like
        Must provide the modalities 'mouse' and 'mouseNafld'.
    adata : AnnData-like
        Modified in place; context rows receive the placeholder label 'c'.
    """
    context = mdata.mod['mouse']
    target = mdata.mod['mouseNafld']

    likelihoods = target.obsm['nlog_likeli_nns_aligned_latent_space']
    aligned_nns = target.obsm['ind_nns_aligned_latent_space']
    hom_nns = target.obsm['ind_nns_hom_genes']
    n_target = target.n_obs

    def vote_aligned(labels):
        # Neighbours sorted by negative log-likelihood, best 25 decide.
        return np.stack([
            most_frequent(labels[aligned_nns[row][np.argsort(likelihoods[row])]][:25])
            for row in range(n_target)
        ])

    def vote_homologous(labels):
        # First k homologous-gene neighbours decide.
        return np.stack([
            most_frequent(labels[hom_nns[row]][:k])
            for row in range(n_target)
        ])

    coarse_labels = context.obs['cell_type_coarse'].to_numpy()
    fine_labels = context.obs['cell_type_fine'].to_numpy()

    coarse_labels_pred = vote_aligned(coarse_labels)
    coarse_labels_pred_hom = vote_homologous(coarse_labels)
    fine_labels_pred = vote_aligned(fine_labels)
    fine_labels_pred_hom = vote_homologous(fine_labels)

    # Context cells come first in adata and get the dummy label 'c'.
    coarse_filler = np.array(['c'] * len(coarse_labels))
    fine_filler = np.array(['c'] * len(fine_labels))
    hom_prefix = key[:-2]

    adata.obsm[key] = np.concatenate((context.obsm['latent_mu'], target.obsm['latent_mu']))
    adata.obs[key + 'pred_coarse'] = np.concatenate((coarse_filler, coarse_labels_pred))
    adata.obs[key + 'pred_fine'] = np.concatenate((fine_filler, fine_labels_pred))
    adata.obs[hom_prefix + 'pred_coarse_hom'] = np.concatenate((coarse_filler, coarse_labels_pred_hom))
    adata.obs[hom_prefix + 'pred_fine_hom'] = np.concatenate((fine_filler, fine_labels_pred_hom))

def _read_embedding(path):
    """Return the embedding stored as ``X`` of an .h5ad file."""
    return ad.read_h5ad(path).X


def _read_scvi_latent(path):
    """Return ``latent_mu`` of the 'train' modality of an unaligned scVI .h5mu file."""
    return mu.read_h5mu(path)['train'].obsm['latent_mu']


# Evaluation of all integration methods on the liver NAFLD data set:
#   1. build a joint AnnData on the shared genes and an unintegrated PCA baseline,
#   2. attach every embedding / label prediction that can be found on disk,
#   3. score label-transfer baselines (Celltypeist, homologous-gene kNN),
#   4. compute scib_metrics scores per embedding,
#   5. write the raw results (JSON) and the mean / std / min-max tables (CSV).
for dataset in ["liver_Nafld"]:
    context_key = 'mouse'
    target_key = 'mouseNafld'
    # Base name of the data files.  It must NOT be overwritten with the paths
    # returned by find_file: the scSpecies file names below are built from it.
    load_key = 'liver'

    file_path = find_file(load_key + ".h5mu", data_path)
    mdata = mu.read_h5mu(file_path)

    mdata[context_key].obs['system'] = 0
    mdata[target_key].obs['system'] = 1

    context_genes = np.array(mdata[context_key].var['human_gene_names'])
    target_genes = np.array(mdata[target_key].var['human_gene_names'])

    # Restrict both species to the genes they share (matched by human gene name).
    ret_vec, ind_a, ind_b = np.intersect1d(target_genes, context_genes, return_indices=True)
    mdata_target = mdata[target_key][:, ind_a]
    mdata_context = mdata[context_key][:, ind_b]

    mdata_target.var_names = target_genes[ind_a]
    mdata_context.var_names = context_genes[ind_b]

    adata = ad.concat([mdata_context, mdata_target], axis=0, join='inner')

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    sc.tl.pca(adata, n_comps=10)
    adata.obsm["Unintegrated"] = adata.obsm["X_pca"]

    del adata.obsm['batch_label_enc']

    # (obsm prefix, file name part before the run index, part after it, loader)
    embedding_sources = [
        ('sysVI_', '_embed_sysVI_', '.h5ad', _read_embedding),
        ('scVI_', '_unaligned_', '_mdata.h5mu', _read_scvi_latent),
        ('scArches_', '_embed_scArches_', '.h5ad', _read_embedding),
        ('scPoli_', '_embed_scPoli_', '.h5ad', _read_embedding),
    ]
    # (obsm / obs prefix, file name infix, k passed to label_pred)
    scspecies_sources = [
        ('scSpecies_1_', '_1_inter_', 1),
        ('scSpecies_25_', '_25_inter_', 25),
        ('scSpecies_25_lat_', '_25_latent_', 25),
        ('scSpecies_250_', '_250_inter_', 250),
    ]

    # NOTE: result names are derived with [:-2] further down (and inside
    # label_pred), which only strips '_<i>' correctly for single-digit i.
    for i in range(k):
        print(i+1, 'of', k)
        run = str(i)

        # Every file is optional: a missing file just leaves that slot out.
        for obsm_prefix, name_head, name_tail, loader in embedding_sources:
            file_path = find_file(dataset + name_head + run + name_tail, data_path)
            if file_path is not None:
                adata.obsm[obsm_prefix + run] = loader(file_path)

        for obs_prefix, infix, n_neighbors in scspecies_sources:
            file_path = find_file(load_key + '_' + context_key + '_' + target_key + infix + run + '.h5mu', data_path)
            if file_path is not None:
                mdata = mu.read_h5mu(file_path)
                label_pred(obs_prefix + run, n_neighbors, mdata, adata)

    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    transl_dict = {'Kupffer cells': 'KCs'}

    human_label_coarse = np.array(adata.obs['cell_type_coarse'][human_ind])
    human_label_coarse = np.array([transl_dict.get(ct, ct) for ct in human_label_coarse])
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_coarse = np.array(adata.obs['cell_type_coarse'][mouse_ind])
    mouse_label_coarse = np.array([transl_dict.get(ct, ct) for ct in mouse_label_coarse])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
    common_cells_coarse = np.intersect1d(mouse_label_coarse, human_label_coarse)
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)

    # Integration metrics are undefined for pure label-transfer baselines;
    # they are stored as NaN so that all rows share the same columns.
    placeholder_metrics = ['PCR', 'bASW', 'kBET', 'GC', 'Species mixing',
                           'ARI', 'NMI', 'iso', 'cASW', 'Biology conservation',
                           'Total']

    erg_dict = {}
    for method in ['Celltypeist', 'kNN_1', 'kNN_25', 'kNN_250']:
        erg_dict[method] = {}

        if method == 'Celltypeist':
            cell_typeist_fine = pd.read_csv(data_path+dataset+'_cell_type_fine_predictions.csv')
            cell_typeist_coarse = pd.read_csv(data_path+dataset+'_cell_type_coarse_predictions.csv')
            # NOTE(review): `perm` is only assigned further down in this cell
            # (and only when subsampling happens), so this relies on a value
            # left over from an earlier cell run - confirm the intended index.
            fine_labels_pred = np.array(cell_typeist_fine['majority_voting'])[perm[human_ind]]
            coarse_labels_pred = np.array(cell_typeist_coarse['majority_voting'])[perm[human_ind]]
        else:
            # 'kNN_25' -> columns 'scSpecies_25pred_*_hom' written by label_pred.
            obs_prefix = 'scSpecies_' + method[len('kNN_'):]
            fine_labels_pred = np.array(adata.obs[obs_prefix + 'pred_fine_hom'])[human_ind]
            coarse_labels_pred = np.array(adata.obs[obs_prefix + 'pred_coarse_hom'])[human_ind]

        metrics = calc_metrics(human_label_coarse, human_label_fine, coarse_labels_pred, fine_labels_pred, common_cells_coarse, common_cells_fine)

        for level in ('coarse', 'fine'):
            for metric_name in placeholder_metrics:
                erg_dict[method][metric_name + ' (' + level + ')'] = [np.nan]

        erg_dict[method]['BAS Label transfer (coarse)'] = [metrics['BAS (Coarse)']]
        erg_dict[method]['BAS Label transfer (fine)'] = [metrics['BAS (Fine)']]

    # Subsample large data sets so the metric computation stays tractable.
    if adata.n_obs > eval_size:
        perm = np.random.permutation(adata.n_obs)[:eval_size]
        adata = adata[perm]

    Unintegrates_names = ['Unintegrated']
    scVI_names = [key for key in adata.obsm.keys() if 'scVI_' in key]
    sysVI_names = [key for key in adata.obsm.keys() if 'sysVI_' in key]
    scArches_names = [key for key in adata.obsm.keys() if 'scArches_' in key]
    scPoli_names = [key for key in adata.obsm.keys() if 'scPoli_' in key]
    scSpecies_1_names = [key for key in adata.obsm.keys() if 'scSpecies_1_' in key]
    # 'scSpecies_25_' is also a substring of 'scSpecies_25_lat_<i>'; those keys
    # belong to the separate latent-alignment model and must be excluded here.
    scSpecies_25_names = [key for key in adata.obsm.keys()
                          if 'scSpecies_25_' in key and 'scSpecies_25_lat_' not in key]
    scSpecies_25_lat_names = [key for key in adata.obsm.keys() if 'scSpecies_25_lat_' in key]
    scSpecies_250_names = [key for key in adata.obsm.keys() if 'scSpecies_250_' in key]
    species_labels = adata.obs['system']
    cell_labels_coarse = adata.obs['cell_type_coarse']
    cell_labels_fine = adata.obs['cell_type_fine']

    # Recompute the label arrays on the (possibly subsampled) data.
    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    human_label_coarse = np.array(adata.obs['cell_type_coarse'][human_ind])
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_coarse = np.array(adata.obs['cell_type_coarse'][mouse_ind])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
    common_cells_coarse = np.intersect1d(mouse_label_coarse, human_label_coarse)
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)

    # (annotation level, labels, extra keyword arguments for isolated_labels).
    # The fine level passes iso_threshold=1, exactly as before.
    label_levels = [
        ('coarse', cell_labels_coarse, {}),
        ('fine', cell_labels_fine, {'iso_threshold': 1}),
    ]
    score_names = ('PCR', 'bASW', 'kBET', 'GC', 'ARI', 'NMI', 'iso', 'cASW', 'BAS')

    for model_list in [Unintegrates_names, sysVI_names, scVI_names, scArches_names, scPoli_names, scSpecies_1_names, scSpecies_25_names, scSpecies_25_lat_names, scSpecies_250_names]:
        if not model_list:
            # Loading is best effort; a method without any file is skipped
            # instead of failing on model_list[0].
            print('No embeddings found for one method, skipping it.')
            continue

        model_name = model_list[0][:-2]
        print(model_name)
        erg_dict[model_name] = {}

        # One list of per-run scores for every metric and annotation level.
        scores = {level: {name: [] for name in score_names} for level in ('coarse', 'fine')}

        for model_key in model_list:

            # Every key in the lists above matches one of the two branches.
            if 'scVI' in model_key or 'sysVI' in model_key or 'scArches' in model_key or 'scPoli' in model_key or 'Unintegrated' in model_key:
                human_label_coarse_pred, human_label_fine_pred = calc_neighbors(model_key, adata)
            elif 'scSpecies' in model_key:
                human_label_coarse_pred = adata.obs[model_key+'pred_coarse'][human_ind]
                human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]

            metrics = calc_metrics(human_label_coarse, human_label_fine, human_label_coarse_pred, human_label_fine_pred, common_cells_coarse, common_cells_fine)
            scores['coarse']['BAS'].append(metrics["BAS (Coarse)"])
            scores['fine']['BAS'].append(metrics["BAS (Fine)"])

            sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
            distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
            neighbors_results_15 = NeighborsResults(distances=distances[:, :15], indices=indices[:, :15])
            neighbors_results_50 = NeighborsResults(distances=distances[:, :50], indices=indices[:, :50])

            embedding = np.array(adata.obsm[model_key])
            species_array = np.array(species_labels)

            # PCR does not depend on the cell labels; it is shared by both levels.
            pcr = scib_metrics.pcr_comparison(X_pre=adata.obsm['Unintegrated'],
                                              X_post=embedding,
                                              covariate=species_array)

            for level, labels, iso_kwargs in label_levels:
                level_scores = scores[level]

                level_scores['PCR'].append(pcr)
                level_scores['bASW'].append(scib_metrics.silhouette_batch(
                    X=embedding, labels=np.array(labels), batch=species_array))
                level_scores['GC'].append(scib_metrics.graph_connectivity(
                    X=neighbors_results_15, labels=labels))
                level_scores['kBET'].append(scib_metrics.kbet_per_label(
                    X=neighbors_results_50, batches=species_labels, labels=labels))

                cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
                    X=neighbors_results_15, labels=labels)
                level_scores['ARI'].append([cluster_scores['ari']])
                level_scores['NMI'].append([cluster_scores['nmi']])

                level_scores['iso'].append(scib_metrics.isolated_labels(
                    X=embedding, labels=labels, batch=species_array, **iso_kwargs))
                level_scores['cASW'].append(scib_metrics.silhouette_label(
                    X=embedding, labels=labels))

        # Store fine before coarse to keep the original column order.
        for level in ('fine', 'coarse'):
            suffix = ' (' + level + ')'
            result = erg_dict[model_name]
            arr = {name: np.array(values).squeeze() for name, values in scores[level].items()}

            result['PCR' + suffix] = arr['PCR']
            result['bASW' + suffix] = arr['bASW']
            result['kBET' + suffix] = arr['kBET']
            result['GC' + suffix] = arr['GC']
            result['Species mixing' + suffix] = 0.25 * (arr['PCR'] + arr['bASW'] + arr['kBET'] + arr['GC'])

            result['ARI' + suffix] = arr['ARI']
            result['NMI' + suffix] = arr['NMI']
            result['iso' + suffix] = arr['iso']
            result['cASW' + suffix] = arr['cASW']
            result['Biology conservation' + suffix] = 0.25 * (arr['ARI'] + arr['NMI'] + arr['iso'] + arr['cASW'])

            result['Total' + suffix] = 0.4 * result['Species mixing' + suffix] + 0.6 * result['Biology conservation' + suffix]
            result['BAS Label transfer' + suffix] = arr['BAS']

        # Rewritten after every method so partial results survive a crash.
        with open(save_path + dataset + "_erg_dict.json", "w") as f:
            json.dump(erg_dict, f, cls=NumpyEncoder)

    mean_dict = {}
    std_dict = {}

    for key, sub_dict in erg_dict.items():
        mean_dict[key] = {metric: np.mean(values) for metric, values in sub_dict.items()}
        std_dict[key] = {metric: np.std(values) for metric, values in sub_dict.items()}

    mean_df = pd.DataFrame(mean_dict).T
    std_df = pd.DataFrame(std_dict).T

    # Column-wise min-max scaling of the mean scores.
    shifted_df = mean_df - mean_df.min(axis=0)
    min_max_df = shifted_df / shifted_df.max(axis=0)

    # All three tables go to save_path (the min-max table used to be written
    # to data_path, unlike its siblings here and in the preceding cell).
    mean_df.to_csv(save_path+dataset+"_mean_df.csv", index=False)
    std_df.to_csv(save_path+dataset+"_std_df.csv", index=False)
    min_max_df.to_csv(save_path+dataset+"_min_max_df.csv", index=False)
    

In [None]:
# Load the per-data-set mean / std score tables written by the evaluation cells.
# NOTE(review): the evaluation cells write these tables to `save_path`, here
# they are read from `data_path` - confirm both point to the same directory.
mean_df_livernafld = pd.read_csv(data_path + "liver_Nafld_mean_df.csv")
std_df_livernafld = pd.read_csv(data_path + "liver_Nafld_std_df.csv")
mean_df_liver = pd.read_csv(data_path + "liver_human_mean_df.csv")
std_df_liver = pd.read_csv(data_path + "liver_human_std_df.csv")
mean_df_glio = pd.read_csv(data_path + "glio_mean_df.csv")
std_df_glio = pd.read_csv(data_path + "glio_std_df.csv")
mean_df_adipose = pd.read_csv(data_path + "adipose_mean_df.csv")
std_df_adipose = pd.read_csv(data_path + "adipose_std_df.csv")

# Row labels: the CSV files were saved without an index, so the method names
# are restored by position.
# NOTE(review): the liver_Nafld evaluation cell creates no SATURN rows, so its
# tables would have 13 rows against these 15 labels - verify the row counts.
index = [
    'Celltypeist', 'kNN_1', 'kNN_25', 'kNN_250',
    'Unintegrat', 'SATURN_PCA_NOANNOT', 'SATURN_PCA_ANNOT', 'sysVI', 'scVI', 'scArches',
    'scPoli', 'scSpecies_1', 'scSpecies_25', 'scSpecies_25_lat', 'scSpecies_250',
]

for table in (
    mean_df_livernafld, std_df_livernafld,
    mean_df_liver, std_df_liver,
    mean_df_glio, std_df_glio,
    mean_df_adipose, std_df_adipose,
):
    table.index = index

# Average over the three cross-species data sets, then min-max scale per column.
all_df = (mean_df_liver + mean_df_glio + mean_df_adipose) / 3
all_min = all_df.min(axis=0)
all_min_max_df = (all_df - all_min) / (all_df.max(axis=0) - all_min)

# Small data set

In [None]:
# Results of the sysVI subsampling experiment in this cell, one entry per
# model key '<repetition>_<subset size>'.
erg_dict_small_sysVI = {}

def label_pred_fine(key, mdata, adata, perm_1, perm_2):
    """Predict fine cell types for a subset of human cells and store them in adata.

    For every selected human cell the stored aligned-latent-space neighbours
    are re-ordered by ascending ``nlog_likeli_nns_aligned_latent_space`` and the
    most frequent mouse label among the first 25 becomes the prediction.

    Parameters
    ----------
    key : str
        ``adata.obsm`` slot for the joint latent representation and prefix of
        the ``pred_fine`` column.
    mdata : MuData-like
        Must provide the modalities 'mouse' and 'human'.
    adata : AnnData-like
        Modified in place; mouse rows receive the placeholder label 'c'.
    perm_1, perm_2 : index arrays
        Selected mouse and human cells (rows of the respective modality).
    """
    human = mdata.mod['human']
    mouse = mdata.mod['mouse']

    likelihoods = human[perm_2].obsm['nlog_likeli_nns_aligned_latent_space']
    fine_labels = mouse.obs['cell_type_fine'].to_numpy()

    # Select the neighbour rows once; doing it inside the loop re-indexed the
    # whole matrix for every single cell.
    neighbour_idx = human.obsm['ind_nns_aligned_latent_space'][perm_2]

    fine_labels_pred = np.stack([
        most_frequent(fine_labels[neighbour_idx[row][np.argsort(likelihoods[row])]][:25])
        for row in range(len(perm_2))
    ])

    adata.obsm[key] = np.concatenate((mouse.obsm['latent_mu'][perm_1], human.obsm['latent_mu'][perm_2]))
    adata.obs[key+'pred_fine'] = np.concatenate((np.array(['c']*len(perm_1)), fine_labels_pred))

def calc_metrics_fine(fine_labels_true, fine_labels_pred, common_cells_fine):
    """Return the balanced accuracy of fine label transfer.

    'Kupffer cells' in the true labels is renamed to 'KCs' first.  Only cells
    whose (renamed) true label is contained in ``common_cells_fine`` are scored.

    Returns
    -------
    dict
        ``{"BAS (Fine)": <balanced accuracy score>}``
    """
    renamed = {'Kupffer cells': 'KCs'}
    true_labels = np.array([renamed.get(label, label) for label in fine_labels_true])

    # Mask of cells whose type occurs in both species.
    shared = np.isin(true_labels, common_cells_fine)
    score = balanced_accuracy_score(true_labels[shared], fine_labels_pred[shared])

    return {"BAS (Fine)": score}

# Evaluation of sysVI embeddings trained on subsets of the liver data.
# For every subset size i and repetition j the embedding is loaded, matched to
# the cells of liver.h5mu, scored with scib_metrics and the result is stored in
# erg_dict_small_sysVI['<j>_<i>'].
dataset = 'liver_human'
context_key = 'mouse'
target_key = 'human'
load_key = 'liver'


def _squeezed(value):
    """Wrap a single score the way the result dictionary stores it (0-d array)."""
    return np.array([value]).squeeze()


for i in [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 20000, 50000, 165680]:

    for j in range(10):
        model_key = str(j)+'_'+str(i)
        # The run with i == 165680 is stored under the plain repetition index.
        embed_id = str(j) if i == 165680 else model_key
        adata = sc.read_h5ad(data_path+'liver_human_embed_sysVI_{}.h5ad'.format(embed_id))

        # Re-read for every run because the 'human' modality is replaced below.
        mdata = mu.read_h5mu(data_path+'liver.h5mu')

        # Row of every embedded human cell inside the full human modality.
        # A dict gives an O(1) lookup per cell (first occurrence wins, as with
        # the previous np.where(...)[0][0]); an unknown cell raises KeyError.
        first_position = {}
        for position, name in enumerate(np.array(mdata.mod['human'].obs_names)):
            first_position.setdefault(name, position)
        ind = np.array([first_position[cell] for cell in np.array(adata[adata.obs['system'] == 'human'].obs_names)])

        mdata.mod['human'] = mdata.mod['human'][ind]
        mdata.mod['human'].obsm[model_key] = adata.X[adata.obs['system'] == 'human']
        mdata.mod['mouse'].obsm[model_key] = adata.X[adata.obs['system'] == 'mouse']

        mdata[context_key].obs['system'] = 0
        mdata[target_key].obs['system'] = 1

        context_genes = np.array(mdata[context_key].var['human_gene_names'])
        target_genes = np.array(mdata[target_key].var['human_gene_names'])

        # Restrict both species to the genes they share (matched by human gene name).
        ret_vec, ind_a, ind_b = np.intersect1d(target_genes, context_genes, return_indices=True)
        mdata_target = mdata[target_key][:, ind_a]
        mdata_context = mdata[context_key][:, ind_b]

        mdata_target.var_names = target_genes[ind_a]
        mdata_context.var_names = context_genes[ind_b]

        # Random subsets: up to 10*i context cells and up to i target cells.
        perm_1 = np.random.permutation(mdata_context.n_obs)[:min(i*10,165680)]
        perm_2 = np.random.permutation(mdata_target.n_obs)[:min(i,16568)]

        adata = ad.concat([mdata_context[perm_1], mdata_target[perm_2]], axis=0, join='inner')

        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        sc.tl.pca(adata, n_comps=10)
        adata.obsm["Unintegrated"] = adata.obsm["X_pca"]

        human_label_coarse_pred, human_label_fine_pred = calc_neighbors(model_key, adata)

        human_ind = np.array(adata.obs.system == 1)
        mouse_ind = np.array(adata.obs.system == 0)
        human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
        mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
        common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
        species_labels = adata.obs['system']
        cell_labels_fine = adata.obs['cell_type_fine']

        metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)
        bas_fine = _squeezed(metrics["BAS (Fine)"])

        sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
        distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
        neighbors_results_15 = NeighborsResults(distances=distances[:, :15], indices=indices[:, :15])

        embedding = np.array(adata.obsm[model_key])
        species_array = np.array(species_labels)

        pcr = _squeezed(scib_metrics.pcr_comparison(
            X_pre=adata.obsm['Unintegrated'],
            X_post=embedding,
            covariate=species_array))

        basw = _squeezed(scib_metrics.silhouette_batch(
            X=embedding,
            labels=np.array(cell_labels_fine),
            batch=species_array))

        graph_conn = _squeezed(scib_metrics.graph_connectivity(
            X=neighbors_results_15,
            labels=cell_labels_fine))

        cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
            X=neighbors_results_15,
            labels=cell_labels_fine)
        ari = _squeezed(cluster_scores['ari'])
        nmi = _squeezed(cluster_scores['nmi'])

        iso = _squeezed(scib_metrics.isolated_labels(
            X=embedding,
            labels=cell_labels_fine,
            batch=species_array))

        casw = _squeezed(scib_metrics.silhouette_label(
            X=embedding,
            labels=cell_labels_fine))

        result = {}
        result['PCR'] = pcr
        result['bASW'] = basw
        result['GC'] = graph_conn
        # NOTE(review): 0.33 (not 1/3) is kept on purpose so the numbers stay
        # comparable with results produced by the same formula elsewhere.
        result['Species mixing'] = 0.33 * (pcr + basw + graph_conn)

        result['ARI'] = ari
        result['NMI'] = nmi
        result['isoF1'] = iso
        result['cASW'] = casw
        result['Biology conservation'] = 0.25 * (ari + nmi + iso + casw)

        result['Total'] = 0.4 * result['Species mixing'] + 0.6 * result['Biology conservation']

        result['BAS Label transfer'] = bas_fine

        erg_dict_small_sysVI[model_key] = result

        # Rewritten after every run so partial results survive a crash.
        with open(data_path + dataset + "_erg_dict_small.json", "w") as f:
            json.dump(erg_dict_small_sysVI, f, cls=NumpyEncoder)

    

In [None]:
# Result dictionary of this cell's small-data-set evaluation
# (presumably filled by the loop further down in this cell).
erg_dict_small = {}

def label_pred_fine(key, mdata, adata, perm_1, perm_2):
    """Predict fine cell types for a subset of human cells and store them in adata.

    For every selected human cell the stored aligned-latent-space neighbours
    are re-ordered by ascending ``nlog_likeli_nns_aligned_latent_space`` and the
    most frequent mouse label among the first 25 becomes the prediction.

    Parameters
    ----------
    key : str
        ``adata.obsm`` slot for the joint latent representation and prefix of
        the ``pred_fine`` column.
    mdata : MuData-like
        Must provide the modalities 'mouse' and 'human'.
    adata : AnnData-like
        Modified in place; mouse rows receive the placeholder label 'c'.
    perm_1, perm_2 : index arrays
        Selected mouse and human cells (rows of the respective modality).
    """
    human = mdata.mod['human']
    mouse = mdata.mod['mouse']

    likelihoods = human[perm_2].obsm['nlog_likeli_nns_aligned_latent_space']
    fine_labels = mouse.obs['cell_type_fine'].to_numpy()

    # Select the neighbour rows once; doing it inside the loop re-indexed the
    # whole matrix for every single cell.
    neighbour_idx = human.obsm['ind_nns_aligned_latent_space'][perm_2]

    fine_labels_pred = np.stack([
        most_frequent(fine_labels[neighbour_idx[row][np.argsort(likelihoods[row])]][:25])
        for row in range(len(perm_2))
    ])

    adata.obsm[key] = np.concatenate((mouse.obsm['latent_mu'][perm_1], human.obsm['latent_mu'][perm_2]))
    adata.obs[key+'pred_fine'] = np.concatenate((np.array(['c']*len(perm_1)), fine_labels_pred))

def calc_metrics_fine(fine_labels_true, fine_labels_pred, common_cells_fine):
    """Score transferred fine labels on the cell types listed in ``common_cells_fine``.

    True labels equal to 'Kupffer cells' are renamed to 'KCs' before the
    comparison. Only cells whose (renamed) true label occurs in
    ``common_cells_fine`` enter the score.

    Returns
    -------
    dict
        ``{"BAS (Fine)": balanced accuracy}``.
    """
    renamed = {'Kupffer cells': 'KCs'}
    fine_labels_true = np.array([renamed.get(label, label) for label in fine_labels_true])

    # Boolean mask of the cells whose true label is shared.
    shared = np.isin(fine_labels_true, common_cells_fine)
    score = balanced_accuracy_score(fine_labels_true[shared], fine_labels_pred[shared])

    return {"BAS (Fine)": score}

dataset = 'liver_human'
context_key = 'mouse'
target_key = 'human'
load_key = 'liver'    

# Evaluate ten models for every subsample size i. At most i human cells and
# i*10 mouse cells are scored per model.
for i in [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 20000, 50000]: 

    # `load_key` stays the bare file stem. The lookup result gets its own
    # name so that later file names are not built from a path (or from None).
    size_path = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(i)+'_mdata.h5mu', data_path) 
    if size_path is None:
        logging.warning('No reference file found for subsample size %s; skipping this size.', i)
        continue
    mdata = mu.read_h5mu(size_path)

    # One random subset per size, shared by all models of that size.
    # NOTE(review): 165680 / 16568 are hard-coded caps -- presumably data set sizes, confirm.
    perm_1 = np.random.permutation(mdata[context_key].n_obs)[:min(i*10,165680)]
    perm_2 = np.random.permutation(mdata[target_key].n_obs)[:min(i,16568)]
    
    for j in range(10):
        model_key = str(j)+'_'+str(i)
        model_path = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(j)+'_size_'+str(i), data_path) 
        if model_path is None:
            # Skip instead of scoring the previously loaded model under this key.
            logging.warning('No file found for model %s; skipping.', model_key)
            continue
        mdata = mu.read_h5mu(model_path)
    
        mdata[context_key].obs['system'] = 0
        mdata[target_key].obs['system'] = 1

        context_genes = np.array(mdata[context_key].var['human_gene_names'])
        target_genes = np.array(mdata[target_key].var['human_gene_names'])

        # Restrict both species to the genes they share and align the names.
        ret_vec, ind_a, ind_b =  np.intersect1d(target_genes, context_genes, return_indices=True)
        mdata_target = mdata[target_key][:, ind_a]
        mdata_context = mdata[context_key][:, ind_b]

        mdata_target.var_names = target_genes[ind_a]
        mdata_context.var_names = context_genes[ind_b]

        adata = ad.concat([mdata_context[perm_1], mdata_target[perm_2]], axis=0, join='inner')

        # PCA of the log-normalised counts is the "Unintegrated" reference for PCR.
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        sc.tl.pca(adata, n_comps=10)
        adata.obsm["Unintegrated"] = adata.obsm["X_pca"]    
            
        label_pred_fine(model_key, mdata, adata, perm_1, perm_2)

        human_ind = np.array(adata.obs.system == 1)
        mouse_ind = np.array(adata.obs.system == 0)
        human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
        mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])       
        common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
        species_labels = adata.obs['system']
        cell_labels_fine = adata.obs['cell_type_fine']
        
        erg_dict_small[model_key] = {}

        # Single-entry accumulators, reset for every model.
        pcr_comparison_fine_list = []
        silhouette_batch_fine_list = []
        graph_connectivity_fine_list = []

        ari_leiden_fine_list = []
        nmi_leiden_fine_list = []
        isolated_labels_fine_list = []
        silhouette_label_fine_list = []
 
        bas_fine_list = []      
        
        human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
        metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)
        bas_fine_list.append(metrics["BAS (Fine)"])

        # kNN graph on the integrated embedding, truncated to 15 / 50 neighbours.
        sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
        distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
        neighbors_results_15 = NeighborsResults(distances=distances[:,:15],indices=indices[:,:15])
        neighbors_results_50 = NeighborsResults(distances=distances[:,:50],indices=indices[:,:50])    

        pcr = scib_metrics.pcr_comparison(X_pre = adata.obsm['Unintegrated'], 
                                    X_post = np.array(adata.obsm[model_key]), 
                                    covariate=np.array(species_labels))
        pcr_comparison_fine_list.append(pcr)

        silhouette_batch_fine_list.append(scib_metrics.silhouette_batch(
                                        X = np.array(adata.obsm[model_key]), 
                                        labels = np.array(cell_labels_fine), 
                                        batch = np.array(species_labels)
                                        )
                                    )

        graph_connectivity_fine_list.append(scib_metrics.graph_connectivity(
                                            X = neighbors_results_15, 
                                            labels = cell_labels_fine
                                            )
                                        )

        # Named `cluster_scores` so the builtin `dict` is not shadowed.
        cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
                                        X=neighbors_results_15, 
                                        labels=cell_labels_fine
                                        )
        ari_leiden_fine_list.append([cluster_scores['ari']])
        nmi_leiden_fine_list.append([cluster_scores['nmi']])

        isolated_labels_fine_list.append(scib_metrics.isolated_labels(
                                        X=np.array(adata.obsm[model_key]), 
                                        labels=cell_labels_fine, 
                                        batch=np.array(species_labels)
                                        )
                                    )

        silhouette_label_fine_list.append(scib_metrics.silhouette_label(
                                        X=np.array(adata.obsm[model_key]), 
                                        labels=cell_labels_fine
                                        )
                                    )

        erg_dict_small[model_key]['PCR'] =  np.array(pcr_comparison_fine_list).squeeze()
        erg_dict_small[model_key]['bASW'] =  np.array(silhouette_batch_fine_list).squeeze()
        erg_dict_small[model_key]['GC'] =  np.array(graph_connectivity_fine_list).squeeze()
        erg_dict_small[model_key]['Species mixing'] = 0.33 * (np.array(pcr_comparison_fine_list).squeeze() + np.array(silhouette_batch_fine_list).squeeze() +  np.array(graph_connectivity_fine_list).squeeze())  
        
        erg_dict_small[model_key]['ARI'] =  np.array(ari_leiden_fine_list).squeeze()
        erg_dict_small[model_key]['NMI'] =  np.array(nmi_leiden_fine_list).squeeze()
        erg_dict_small[model_key]['isoF1'] =  np.array(isolated_labels_fine_list).squeeze()
        erg_dict_small[model_key]['cASW'] =  np.array(silhouette_label_fine_list).squeeze()      
        erg_dict_small[model_key]['Biology conservation'] =  0.25 * (np.array(ari_leiden_fine_list).squeeze() + np.array(nmi_leiden_fine_list).squeeze() + np.array(isolated_labels_fine_list).squeeze() + np.array(silhouette_label_fine_list).squeeze())

        erg_dict_small[model_key]['Total'] =  0.4 * erg_dict_small[model_key]['Species mixing'] + 0.6 * erg_dict_small[model_key]['Biology conservation']

        erg_dict_small[model_key]['BAS Label transfer'] =  np.array(bas_fine_list).squeeze()

        # Rewritten after every model so partial results survive an interruption.
        with open(data_path + dataset + "_erg_dict_small.json", "w") as f:
            json.dump(erg_dict_small, f, cls=NumpyEncoder)   


In [None]:

# Upper bound on the number of target cells sampled for the full-data evaluation.
eval_size_1=10000
# Upper bound on the number of context cells sampled for the full-data evaluation.
eval_size_2=100000

def label_pred_fine(key, mdata, adata, perm_1, perm_2):
    """Write the latent embedding and transferred fine labels onto ``adata``.

    Each selected human cell gets the most frequent mouse ``cell_type_fine``
    among its first 25 neighbours, after ordering the neighbours by ascending
    ``nlog_likeli_nns_aligned_latent_space``.

    Note that the index roles are the reverse of the row order: ``perm_2``
    selects the mouse cells (placed first) and ``perm_1`` the human cells
    (placed second).

    Parameters
    ----------
    key : str
        Name of the ``adata.obsm`` entry that receives the latent means;
        ``key + 'pred_fine'`` names the ``adata.obs`` column with the predictions.
    mdata
        Container exposing ``mod['mouse']`` and ``mod['human']``.
    adata
        Modified in place. Its rows are assumed to be the ``perm_2`` mouse
        cells followed by the ``perm_1`` human cells -- TODO confirm at call site.
    perm_1
        Integer index array selecting human cells.
    perm_2
        Integer index array selecting mouse cells.
    """
    human = mdata.mod['human']
    mouse = mdata.mod['mouse']

    likelihoods = human.obsm['nlog_likeli_nns_aligned_latent_space']
    nn_indices = human.obsm['ind_nns_aligned_latent_space']
    fine_labels = mouse.obs['cell_type_fine'].to_numpy()

    # Only the human cells in perm_1 end up in adata, so predict for those alone.
    fine_labels_pred = np.stack([
        most_frequent(fine_labels[nn_indices[i][np.argsort(likelihoods[i])]][:25])
        for i in perm_1
    ])

    adata.obsm[key] = np.concatenate((mouse.obsm['latent_mu'][perm_2], human.obsm['latent_mu'][perm_1]))
    # Mouse rows get the placeholder 'c'; only human rows carry a prediction.
    adata.obs[key+'pred_fine'] = np.concatenate((np.array(['c']*len(perm_2)), fine_labels_pred))
  

# Fresh accumulators: one entry per model is appended below. Without the reset
# the lists would still hold values left over from the previously run cell.
pcr_comparison_fine_list = []
silhouette_batch_fine_list = []
graph_connectivity_fine_list = []
ari_leiden_fine_list = []
nmi_leiden_fine_list = []
isolated_labels_fine_list = []
silhouette_label_fine_list = []
bas_fine_list = []

# File stem of the models, the same value the sibling evaluation cells assign.
# It is reset here because an earlier cell may have replaced it with a file path.
load_key = 'liver'

erg_dict_small['full_data'] = {}

for i in range(10):
    model_key = 'full_data'+str(i)
    # The lookup result gets its own name so `load_key` stays the bare stem.
    model_path = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(i)+'.h5mu', data_path) 
    if model_path is None:
        # Skip instead of scoring the previously loaded model a second time.
        logging.warning('No file found for model %s; skipping.', model_key)
        continue
    mdata = mu.read_h5mu(model_path)

    mdata[context_key].obs['system'] = 0
    mdata[target_key].obs['system'] = 1

    context_genes = np.array(mdata[context_key].var['human_gene_names'])
    target_genes = np.array(mdata[target_key].var['human_gene_names'])

    # Restrict both species to the genes they share and align the names.
    ret_vec, ind_a, ind_b =  np.intersect1d(target_genes, context_genes, return_indices=True)
    mdata_target = mdata[target_key][:, ind_a]
    mdata_context = mdata[context_key][:, ind_b]

    mdata_target.var_names = target_genes[ind_a]
    mdata_context.var_names = context_genes[ind_b]

    # perm_1 indexes target cells, perm_2 indexes context cells.
    perm_1 = np.random.permutation(mdata_target.n_obs)[:eval_size_1]
    perm_2 = np.random.permutation(mdata_context.n_obs)[:eval_size_2]
    adata = ad.concat([mdata_context[perm_2], mdata_target[perm_1]], axis=0, join='inner')

    # PCA of the log-normalised counts is the "Unintegrated" reference for PCR.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.tl.pca(adata, n_comps=10)
    adata.obsm["Unintegrated"] = adata.obsm["X_pca"]    
        
    label_pred_fine(model_key, mdata, adata, perm_1, perm_2)

    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])       
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
    species_labels = adata.obs['system']
    cell_labels_fine = adata.obs['cell_type_fine']
    
    human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
    metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)
    bas_fine_list.append(metrics["BAS (Fine)"])

    # kNN graph on the integrated embedding, truncated to 15 / 50 neighbours.
    sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
    distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
    neighbors_results_15 = NeighborsResults(distances=distances[:,:15],indices=indices[:,:15])
    neighbors_results_50 = NeighborsResults(distances=distances[:,:50],indices=indices[:,:50])    

    pcr = scib_metrics.pcr_comparison(X_pre = adata.obsm['Unintegrated'], 
                                X_post = np.array(adata.obsm[model_key]), 
                                covariate=np.array(species_labels))
    pcr_comparison_fine_list.append(pcr)

    silhouette_batch_fine_list.append(scib_metrics.silhouette_batch(
                                    X = np.array(adata.obsm[model_key]), 
                                    labels = np.array(cell_labels_fine), 
                                    batch = np.array(species_labels)
                                    )
                                )

    graph_connectivity_fine_list.append(scib_metrics.graph_connectivity(
                                        X = neighbors_results_15, 
                                        labels = cell_labels_fine
                                        )
                                    )

    # Named `cluster_scores` so the builtin `dict` is not shadowed.
    cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
                                    X=neighbors_results_15, 
                                    labels=cell_labels_fine
                                    )
    ari_leiden_fine_list.append([cluster_scores['ari']])
    nmi_leiden_fine_list.append([cluster_scores['nmi']])

    isolated_labels_fine_list.append(scib_metrics.isolated_labels(
                                    X=np.array(adata.obsm[model_key]), 
                                    labels=cell_labels_fine, 
                                    batch=np.array(species_labels)
                                    )
                                )

    silhouette_label_fine_list.append(scib_metrics.silhouette_label(
                                    X=np.array(adata.obsm[model_key]), 
                                    labels=cell_labels_fine
                                    )
                                )

# Each entry below is an array with one value per evaluated model.
erg_dict_small['full_data']['bASW'] =  np.array(silhouette_batch_fine_list).squeeze()
erg_dict_small['full_data']['PCR'] =  np.array(pcr_comparison_fine_list).squeeze()
erg_dict_small['full_data']['GC'] =  np.array(graph_connectivity_fine_list).squeeze()
erg_dict_small['full_data']['Species mixing'] = 0.33 * (np.array(silhouette_batch_fine_list).squeeze() + np.array(pcr_comparison_fine_list).squeeze() + np.array(graph_connectivity_fine_list).squeeze())  

erg_dict_small['full_data']['ARI'] =  np.array(ari_leiden_fine_list).squeeze()
erg_dict_small['full_data']['NMI'] =  np.array(nmi_leiden_fine_list).squeeze()
erg_dict_small['full_data']['isoF1'] =  np.array(isolated_labels_fine_list).squeeze()
erg_dict_small['full_data']['cASW'] =  np.array(silhouette_label_fine_list).squeeze()      
erg_dict_small['full_data']['Biology conservation'] =  0.25 * (np.array(ari_leiden_fine_list).squeeze() + np.array(nmi_leiden_fine_list).squeeze() + np.array(isolated_labels_fine_list).squeeze() + np.array(silhouette_label_fine_list).squeeze())

erg_dict_small['full_data']['Total'] =  0.4 * erg_dict_small['full_data']['Species mixing'] + 0.6 * erg_dict_small['full_data']['Biology conservation']

erg_dict_small['full_data']['BAS Label transfer'] =  np.array(bas_fine_list).squeeze()  

In [None]:
# Mean and standard deviation of every metric, keyed by subsample size (as a string).
erg_dict_mean_small = {}
erg_dict_std_small = {}

# (name used in the summary dicts, name used in erg_dict_small), in insertion order.
subsample_metrics = [
    ('PCR', 'PCR'),
    ('bASW', 'bASW'),
    ('GC', 'GC'),
    ('ARI', 'ARI'),
    ('NMI', 'NMI'),
    ('isoF1', 'isoF1'),
    ('cASW', 'cASW'),
    ('BAS', 'BAS Label transfer'),
    ('Biology', 'Biology conservation'),
    ('Species', 'Species mixing'),
    ('Total', 'Total'),
]
# The full-data summary lists bASW before PCR; everything else is in the same order.
full_data_metrics = [subsample_metrics[1], subsample_metrics[0]] + subsample_metrics[2:]

# NOTE(review): only the first three models per size are aggregated, whereas
# the evaluation loop of this notebook iterates over range(10) -- confirm.
n_runs = 3

for i in [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 20000, 50000]: 
    key = str(i)
    model_keys = [str(j)+'_'+str(i) for j in range(n_runs)]

    erg_dict_mean_small[key] = {}
    erg_dict_std_small[key] = {}

    for summary_name, source_name in subsample_metrics:
        values = np.array([erg_dict_small[run_key][source_name] for run_key in model_keys])
        erg_dict_mean_small[key][summary_name] = np.mean(values)
        erg_dict_std_small[key][summary_name] = np.std(values)


# Full-data models: the per-model values are already stored as one array per metric.
# NOTE(review): 146839 is presumably the size of the full data set -- confirm.
key = str(146839)
model_key = 'full_data'

erg_dict_mean_small[key] = {}
erg_dict_std_small[key] = {}

for summary_name, source_name in full_data_metrics:
    values = np.array(erg_dict_small[model_key][source_name])
    erg_dict_mean_small[key][summary_name] = np.mean(values)
    erg_dict_std_small[key][summary_name] = np.std(values)

# Reduced gene set

In [None]:
# Cap applied to the random permutation of context cells drawn in the loops of this cell.
size = 50000

def label_pred_fine(key, mdata, adata):
    """Write the latent embedding and transferred fine labels onto ``adata``.

    Each human cell gets the most frequent mouse ``cell_type_fine`` among its
    first 25 neighbours, after ordering the neighbours by ascending
    ``nlog_likeli_nns_aligned_latent_space``. Mouse cells get the placeholder 'c'.

    The mouse-then-human arrays are indexed with the module-level ``perm``,
    which is not a parameter and must be defined before this function is called.

    Parameters
    ----------
    key : str
        Name of the ``adata.obsm`` entry that receives the latent means;
        ``key + 'pred_fine'`` names the ``adata.obs`` column with the predictions.
    mdata
        Container exposing ``mod['mouse']`` and ``mod['human']``.
    adata
        Modified in place.
    """
    human = mdata.mod['human']
    mouse = mdata.mod['mouse']

    likelihoods = human.obsm['nlog_likeli_nns_aligned_latent_space']
    fine_labels = mouse.obs['cell_type_fine'].to_numpy()

    predictions = []
    for cell in range(human.n_obs):
        ranking = np.argsort(likelihoods[cell])
        neighbours = human.obsm['ind_nns_aligned_latent_space'][cell][ranking]
        predictions.append(most_frequent(fine_labels[neighbours][:25]))
    fine_labels_pred = np.stack(predictions)

    latent_all = np.concatenate((mouse.obsm['latent_mu'], human.obsm['latent_mu']))
    labels_all = np.concatenate((np.array(['c']*len(fine_labels)), fine_labels_pred))

    adata.obsm[key] = latent_all[perm]
    adata.obs[key+'pred_fine'] = labels_all[perm]

def calc_metrics_fine(fine_labels_true, fine_labels_pred, common_cells_fine):
    """Return the balanced accuracy of predicted fine labels on shared cell types.

    'Kupffer cells' in the true labels is mapped to 'KCs' first; cells whose
    mapped true label is not contained in ``common_cells_fine`` are ignored.

    Returns
    -------
    dict
        A single entry, ``"BAS (Fine)"``, holding the balanced accuracy score.
    """
    synonyms = {'Kupffer cells': 'KCs'}
    fine_labels_true = np.array([synonyms.get(name, name) for name in fine_labels_true])

    # Mask of the cells that take part in the score.
    keep = np.isin(fine_labels_true, common_cells_fine)

    metrics = {}
    metrics["BAS (Fine)"] = balanced_accuracy_score(fine_labels_true[keep], fine_labels_pred[keep])
    return metrics

dataset = 'liver_human'
context_key = 'mouse'
target_key = 'human'
load_key = 'liver'    

# Results of the reduced-gene-set evaluation: maps a model key to a dict of metric values.
erg_dict_hom = {}

# Ten models per reduced gene set; the sets are labelled i*200 in the file names.
for i in range(8):
    for j in range(10):
        model_key = str(j)+'_'+str(i*200)
        # `load_key` stays the bare file stem. The lookup result gets its own
        # name so that later file names are not built from a path (or from None).
        model_path = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(j)+'_reduced_'+str(i*200)+'.h5mu', data_path) 
        if model_path is None:
            # Skip instead of scoring the previously loaded model under this key.
            logging.warning('No file found for model %s; skipping.', model_key)
            continue
        mdata = mu.read_h5mu(model_path)
    
        mdata[context_key].obs['system'] = 0
        mdata[target_key].obs['system'] = 1

        context_genes = np.array(mdata[context_key].var['human_gene_names'])
        target_genes = np.array(mdata[target_key].var['human_gene_names'])

        # Restrict both species to the genes they share and align the names.
        ret_vec, ind_a, ind_b =  np.intersect1d(target_genes, context_genes, return_indices=True)
        mdata_target = mdata[target_key][:, ind_a]
        mdata_context = mdata[context_key][:, ind_b]

        mdata_target.var_names = target_genes[ind_a]
        mdata_context.var_names = context_genes[ind_b]

        # NOTE(review): this draw is never used for indexing (the joint
        # permutation below is what sub-samples adata). It is kept only so the
        # sequence of random numbers consumed per model stays the same.
        unused_context_perm = np.random.permutation(mdata_context.n_obs)[:size]

        adata = ad.concat([mdata_context, mdata_target], axis=0, join='inner')

        # `perm` is also read by label_pred_fine as a module-level name.
        # NOTE(review): `eval_size` is not defined in this cell -- it must exist already.
        perm = np.random.permutation(adata.n_obs)[:eval_size]
        # Copy so the in-place preprocessing below does not run on a view.
        adata = adata[perm].copy()

        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        sc.tl.pca(adata, n_comps=10)
        adata.obsm["Unintegrated"] = adata.obsm["X_pca"]    
            
        label_pred_fine(model_key, mdata, adata)

        human_ind = np.array(adata.obs.system == 1)
        mouse_ind = np.array(adata.obs.system == 0)
        human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
        mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])       
        common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
        species_labels = adata.obs['system']
        cell_labels_fine = adata.obs['cell_type_fine']
        
        erg_dict_hom[model_key] = {}

        # Single-entry accumulators, reset for every model.
        silhouette_batch_fine_list = []
        graph_connectivity_fine_list = []
        kbet_fine_list = []

        ari_leiden_fine_list = []
        nmi_leiden_fine_list = []
        isolated_labels_fine_list = []
        silhouette_label_fine_list = []
 
        bas_fine_list = []      
        
        human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
        metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)
        bas_fine_list.append(metrics["BAS (Fine)"])

        # kNN graph on the integrated embedding, truncated to 15 / 50 neighbours.
        sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
        distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
        neighbors_results_15 = NeighborsResults(distances=distances[:,:15],indices=indices[:,:15])
        neighbors_results_50 = NeighborsResults(distances=distances[:,:50],indices=indices[:,:50])    

        silhouette_batch_fine_list.append(scib_metrics.silhouette_batch(
                                        X = np.array(adata.obsm[model_key]), 
                                        labels = np.array(cell_labels_fine), 
                                        batch = np.array(species_labels)
                                        )
                                    )

        graph_connectivity_fine_list.append(scib_metrics.graph_connectivity(
                                            X = neighbors_results_15, 
                                            labels = cell_labels_fine
                                            )
                                        )
    
        kbet_fine_list.append(scib_metrics.kbet_per_label(
                                        X = neighbors_results_50, 
                                        batches = species_labels,
                                        labels = cell_labels_fine
                                        )
                                    )    

        # Named `cluster_scores` so the builtin `dict` is not shadowed.
        cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
                                        X=neighbors_results_15, 
                                        labels=cell_labels_fine
                                        )
        ari_leiden_fine_list.append([cluster_scores['ari']])
        nmi_leiden_fine_list.append([cluster_scores['nmi']])

        isolated_labels_fine_list.append(scib_metrics.isolated_labels(
                                        X=np.array(adata.obsm[model_key]), 
                                        labels=cell_labels_fine, 
                                        batch=np.array(species_labels)
                                        )
                                    )

        silhouette_label_fine_list.append(scib_metrics.silhouette_label(
                                        X=np.array(adata.obsm[model_key]), 
                                        labels=cell_labels_fine
                                        )
                                    )

        erg_dict_hom[model_key]['bASW'] =  np.array(silhouette_batch_fine_list).squeeze()
        erg_dict_hom[model_key]['kBET'] =  np.array(kbet_fine_list).squeeze()
        erg_dict_hom[model_key]['GC'] =  np.array(graph_connectivity_fine_list).squeeze()
        erg_dict_hom[model_key]['Species mixing'] = 0.33 * (np.array(silhouette_batch_fine_list).squeeze() + np.array(kbet_fine_list).squeeze() + np.array(graph_connectivity_fine_list).squeeze())  
        
        erg_dict_hom[model_key]['ARI'] =  np.array(ari_leiden_fine_list).squeeze()
        erg_dict_hom[model_key]['NMI'] =  np.array(nmi_leiden_fine_list).squeeze()
        erg_dict_hom[model_key]['isoF1'] =  np.array(isolated_labels_fine_list).squeeze()
        erg_dict_hom[model_key]['cASW'] =  np.array(silhouette_label_fine_list).squeeze()      
        erg_dict_hom[model_key]['Biology conservation'] =  0.25 * (np.array(ari_leiden_fine_list).squeeze() + np.array(nmi_leiden_fine_list).squeeze() + np.array(isolated_labels_fine_list).squeeze() + np.array(silhouette_label_fine_list).squeeze())

        erg_dict_hom[model_key]['Total'] =  0.4 * erg_dict_hom[model_key]['Species mixing'] + 0.6 * erg_dict_hom[model_key]['Biology conservation']

        erg_dict_hom[model_key]['BAS Label transfer'] =  np.array(bas_fine_list).squeeze()

        # Rewritten after every model so partial results survive an interruption.
        with open(data_path + dataset + "_erg_dict_hom.json", "w") as f:
            json.dump(erg_dict_hom, f, cls=NumpyEncoder)   


pcr_comparison_fine_list = []
silhouette_batch_fine_list = []
graph_connectivity_fine_list = []
kbet_fine_list = []
ari_leiden_fine_list = []
nmi_leiden_fine_list = []
isolated_labels_fine_list = []
silhouette_label_fine_list = []
bas_fine_list = []      

for i in range(10):
    load_key = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(j)+'_no_shared.h5mu', data_path) 
    if load_key != None:  
        mdata = mu.read_h5mu(load_key)    

    mdata[context_key].obs['system'] = 0
    mdata[target_key].obs['system'] = 1

    context_genes = np.array(mdata[context_key].var['human_gene_names'])
    target_genes = np.array(mdata[target_key].var['human_gene_names'])

    ret_vec, ind_a, ind_b =  np.intersect1d(target_genes, context_genes, return_indices=True)
    mdata_target = mdata[target_key][:, ind_a]
    mdata_context = mdata[context_key][:, ind_b]

    mdata_target.var_names = target_genes[ind_a]
    mdata_context.var_names = context_genes[ind_b]

    perm = np.random.permutation(mdata_context.n_obs)[:size]

    adata = ad.concat([mdata_context, mdata_target], axis=0, join='inner')#[perm]

    perm = np.random.permutation(adata.n_obs)[:eval_size]
    adata = adata[perm]  
 
    label_pred_fine(model_key, mdata, adata)

    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])       
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
    species_labels = adata.obs['system']
    cell_labels_fine = adata.obs['cell_type_fine']
    
    erg_dict_hom[model_key] = {}

    
    human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
    metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)

    bas_fine_list.append(metrics["BAS (Fine)"])

    sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key) #, key_added=model_key
    distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities'])) # adata.obsp.get(model_key+'_distances', adata.obsp[model_key+'_connectivities'])
    neighbors_results_15 = NeighborsResults(distances=distances[:,:15],indices=indices[:,:15])
    neighbors_results_50 = NeighborsResults(distances=distances[:,:50],indices=indices[:,:50])    


    silhouette_batch_fine_list.append(scib_metrics.silhouette_batch(
                                    X = np.array(adata.obsm[model_key]), 
                                    labels = np.array(cell_labels_fine), 
                                    batch = np.array(species_labels)
                                    )
                                )

    graph_connectivity_fine_list.append(scib_metrics.graph_connectivity(
                                        X = neighbors_results_15, 
                                        labels = cell_labels_fine
                                        )
                                    )

    kbet_fine_list.append(scib_metrics.kbet_per_label(
                                    X = neighbors_results_50, 
                                    batches = species_labels,
                                    labels = cell_labels_fine
                                    )
                                )    

    dict = scib_metrics.nmi_ari_cluster_labels_leiden(
                                    X=neighbors_results_15, 
                                    labels=cell_labels_fine
                                    )
    
    ari_leiden_fine_list.append([dict['ari']])
    nmi_leiden_fine_list.append([dict['nmi']])

    isolated_labels_fine_list.append(scib_metrics.isolated_labels(
                                    X=np.array(adata.obsm[model_key]), 
                                    labels=cell_labels_fine, 
                                    batch=np.array(species_labels)
                                    )
                                )

    silhouette_label_fine_list.append(scib_metrics.silhouette_label(
                                    X=np.array(adata.obsm[model_key]), 
                                    labels=cell_labels_fine
                                    )
                                )

# Store the per-replicate scores collected above under erg_dict_hom['no_shared']
# and write the whole result dictionary to disk.
#erg_dict_hom[model_key]['PCR'] =  np.array(pcr_comparison_fine_list).squeeze()
no_shared_scores = erg_dict_hom['no_shared']

# Species-mixing scores and their combined value (weight 0.33 on the sum).
for metric_name, replicate_scores in (
    ('bASW', silhouette_batch_fine_list),
    ('kBET', kbet_fine_list),
    ('GC', graph_connectivity_fine_list),
):
    no_shared_scores[metric_name] = np.array(replicate_scores).squeeze()
no_shared_scores['Species mixing'] = 0.33 * (
    no_shared_scores['bASW'] + no_shared_scores['kBET'] + no_shared_scores['GC']
)

# Biology-conservation scores and their combined value (weight 0.25 on the sum).
for metric_name, replicate_scores in (
    ('ARI', ari_leiden_fine_list),
    ('NMI', nmi_leiden_fine_list),
    ('isoF1', isolated_labels_fine_list),
    ('cASW', silhouette_label_fine_list),
):
    no_shared_scores[metric_name] = np.array(replicate_scores).squeeze()
no_shared_scores['Biology conservation'] = 0.25 * (
    no_shared_scores['ARI']
    + no_shared_scores['NMI']
    + no_shared_scores['isoF1']
    + no_shared_scores['cASW']
)

# Overall score: 40 % species mixing, 60 % biology conservation.
no_shared_scores['Total'] = (
    0.4 * no_shared_scores['Species mixing']
    + 0.6 * no_shared_scores['Biology conservation']
)

no_shared_scores['BAS Label transfer'] = np.array(bas_fine_list).squeeze()

# NumpyEncoder is defined elsewhere in this file.
with open(data_path + dataset + "_erg_dict_hom.json", "w") as out_file:
    json.dump(erg_dict_hom, out_file, cls=NumpyEncoder)


# ---- Evaluation of the 'full_genes' runs ------------------------------------
# Per-replicate score accumulators: every successfully loaded replicate
# appends exactly one entry to each list that the loop below fills.
pcr_comparison_fine_list = []  # reset for symmetry; not filled by this loop
silhouette_batch_fine_list = []
graph_connectivity_fine_list = []
kbet_fine_list = []
ari_leiden_fine_list = []
nmi_leiden_fine_list = []
isolated_labels_fine_list = []
silhouette_label_fine_list = []
bas_fine_list = []

# The summary statements that follow this loop write to
# erg_dict_hom['full_genes'], so that entry is created once, here, instead of
# re-initialising erg_dict_hom[model_key] on every replicate.
erg_dict_hom['full_genes'] = {}

# NOTE(review): load_key, context_key, target_key, data_path, eval_size and
# model_key come from earlier cells. load_key is assumed to still hold the
# file-name prefix, and model_key the adata.obsm slot / obs-column prefix that
# label_pred_fine fills -- confirm.
for i in range(10):
    # Use the loop index for the replicate number and keep the prefix in
    # `load_key` intact so that every iteration builds a valid file name.
    file_path = find_file(load_key+'_'+context_key+'_'+target_key+'_25_inter_'+str(i)+'_full_genes.h5mu', data_path)
    if file_path is None:
        # Skip missing replicates instead of silently re-scoring the data
        # that was loaded for a previous replicate.
        logging.warning("No full_genes result file found for replicate %s; skipping it.", i)
        continue
    mdata = mu.read_h5mu(file_path)

    # Species indicator used as the batch variable: 0 = context, 1 = target.
    mdata[context_key].obs['system'] = 0
    mdata[target_key].obs['system'] = 1

    context_genes = np.array(mdata[context_key].var['human_gene_names'])
    target_genes = np.array(mdata[target_key].var['human_gene_names'])

    # Restrict both modalities to the gene names they share and give them
    # identical var_names so that they can be concatenated.
    _, ind_a, ind_b = np.intersect1d(target_genes, context_genes, return_indices=True)
    mdata_target = mdata[target_key][:, ind_a]
    mdata_context = mdata[context_key][:, ind_b]

    mdata_target.var_names = target_genes[ind_a]
    mdata_context.var_names = context_genes[ind_b]

    adata = ad.concat([mdata_context, mdata_target], axis=0, join='inner')

    # Score a random subset of at most eval_size cells.
    perm = np.random.permutation(adata.n_obs)[:eval_size]
    adata = adata[perm]

    # NOTE(review): label_pred_fine is defined elsewhere and is called on the
    # subsampled, reordered adata -- confirm that it handles this correctly.
    label_pred_fine(model_key, mdata, adata)

    human_ind = np.array(adata.obs.system == 1)
    mouse_ind = np.array(adata.obs.system == 0)
    human_label_fine = np.array(adata.obs['cell_type_fine'][human_ind])
    mouse_label_fine = np.array(adata.obs['cell_type_fine'][mouse_ind])
    # Fine cell types that occur in both groups.
    common_cells_fine = np.intersect1d(mouse_label_fine, human_label_fine)
    species_labels = adata.obs['system']
    cell_labels_fine = adata.obs['cell_type_fine']

    # Label-transfer score on the cells with system == 1.
    human_label_fine_pred = adata.obs[model_key+'pred_fine'][human_ind]
    metrics = calc_metrics_fine(human_label_fine, human_label_fine_pred, common_cells_fine)

    bas_fine_list.append(metrics["BAS (Fine)"])

    # One 50-neighbour graph on the stored representation; the 15- and
    # 50-neighbour inputs of the graph-based metrics are sliced from it.
    sc.pp.neighbors(adata, n_neighbors=50, use_rep=model_key)
    distances, indices = scib_metrics.utils.convert_knn_graph_to_idx(adata.obsp.get('distances', adata.obsp['connectivities']))
    neighbors_results_15 = NeighborsResults(distances=distances[:,:15],indices=indices[:,:15])
    neighbors_results_50 = NeighborsResults(distances=distances[:,:50],indices=indices[:,:50])

    # Species-mixing metrics.
    silhouette_batch_fine_list.append(scib_metrics.silhouette_batch(
                                    X = np.array(adata.obsm[model_key]),
                                    labels = np.array(cell_labels_fine),
                                    batch = np.array(species_labels)
                                    )
                                )

    graph_connectivity_fine_list.append(scib_metrics.graph_connectivity(
                                        X = neighbors_results_15,
                                        labels = cell_labels_fine
                                        )
                                    )

    kbet_fine_list.append(scib_metrics.kbet_per_label(
                                    X = neighbors_results_50,
                                    batches = species_labels,
                                    labels = cell_labels_fine
                                    )
                                )

    # Biology-conservation metrics.
    cluster_scores = scib_metrics.nmi_ari_cluster_labels_leiden(
                                    X=neighbors_results_15,
                                    labels=cell_labels_fine
                                    )

    # Wrapped in a list as before; the summary code squeezes the extra axis.
    ari_leiden_fine_list.append([cluster_scores['ari']])
    nmi_leiden_fine_list.append([cluster_scores['nmi']])

    isolated_labels_fine_list.append(scib_metrics.isolated_labels(
                                    X=np.array(adata.obsm[model_key]),
                                    labels=cell_labels_fine,
                                    batch=np.array(species_labels)
                                    )
                                )

    silhouette_label_fine_list.append(scib_metrics.silhouette_label(
                                    X=np.array(adata.obsm[model_key]),
                                    labels=cell_labels_fine
                                    )
                                )

# Store the per-replicate scores of the 'full_genes' runs and write the whole
# result dictionary to disk (same file as the earlier dump, now including
# this entry).
full_genes_scores = erg_dict_hom['full_genes']

# Species-mixing scores and their combined value (weight 0.33 on the sum).
for metric_name, replicate_scores in (
    ('bASW', silhouette_batch_fine_list),
    ('kBET', kbet_fine_list),
    ('GC', graph_connectivity_fine_list),
):
    full_genes_scores[metric_name] = np.array(replicate_scores).squeeze()
full_genes_scores['Species mixing'] = 0.33 * (
    full_genes_scores['bASW'] + full_genes_scores['kBET'] + full_genes_scores['GC']
)

# Biology-conservation scores and their combined value (weight 0.25 on the sum).
for metric_name, replicate_scores in (
    ('ARI', ari_leiden_fine_list),
    ('NMI', nmi_leiden_fine_list),
    ('isoF1', isolated_labels_fine_list),
    ('cASW', silhouette_label_fine_list),
):
    full_genes_scores[metric_name] = np.array(replicate_scores).squeeze()
full_genes_scores['Biology conservation'] = 0.25 * (
    full_genes_scores['ARI']
    + full_genes_scores['NMI']
    + full_genes_scores['isoF1']
    + full_genes_scores['cASW']
)

# Overall score: 40 % species mixing, 60 % biology conservation.
full_genes_scores['Total'] = (
    0.4 * full_genes_scores['Species mixing']
    + 0.6 * full_genes_scores['Biology conservation']
)

full_genes_scores['BAS Label transfer'] = np.array(bas_fine_list).squeeze()

# NumpyEncoder is defined elsewhere in this file.
with open(data_path + dataset + "_erg_dict_hom.json", "w") as out_file:
    json.dump(erg_dict_hom, out_file, cls=NumpyEncoder)
    

In [None]:
# Aggregate the scores stored in erg_dict_hom into two tables (mean and
# standard deviation) with one row per condition and one column per metric.

# (column name in the summary tables, key used inside erg_dict_hom).
# The order fixes the column order of df_mean_hom / df_std_hom.
_SUMMARY_COLUMNS = (
    ('bASW', 'bASW'),
    ('kBET', 'kBET'),
    ('GC', 'GC'),
    ('ARI', 'ARI'),
    ('NMI', 'NMI'),
    ('isoF1', 'isoF1'),
    ('cASW', 'cASW'),
    ('BAS', 'BAS Label transfer'),
    ('Biology', 'Biology conservation'),
    ('Species', 'Species mixing'),
    ('Total', 'Total'),
)


def _summarize(values_by_column):
    """Return ``(means, stds)`` for a mapping of column name -> scores.

    Each value is converted with ``np.array`` and reduced over all of its
    elements with ``np.mean`` / ``np.std`` (NumPy's default ``ddof=0``).
    """
    means = {column: np.mean(np.array(values)) for column, values in values_by_column.items()}
    stds = {column: np.std(np.array(values)) for column, values in values_by_column.items()}
    return means, stds


erg_dict_mean_hom = {}
erg_dict_std_hom = {}

# Sweep conditions 200, 400, ..., 1600: each has ten entries in erg_dict_hom,
# stored under '<replicate>_<condition>'.
for i in range(8):
    key = str((i+1)*200)
    per_replicate = {column: [] for column, _ in _SUMMARY_COLUMNS}

    for j in range(10):
        model_key = str(j)+'_'+str((i+1)*200)
        for column, source in _SUMMARY_COLUMNS:
            per_replicate[column].append(erg_dict_hom[model_key][source])

    erg_dict_mean_hom[key], erg_dict_std_hom[key] = _summarize(per_replicate)

# The two reference runs already hold all of their replicate scores in a
# single entry; they are stored under the row labels '0' and '1808'.
# NOTE(review): these labels presumably give the sweep value each run
# corresponds to -- confirm.
for key, model_key in ((str(0), 'full_genes'), (str(1808), 'no_shared')):
    pooled = {column: erg_dict_hom[model_key][source] for column, source in _SUMMARY_COLUMNS}
    erg_dict_mean_hom[key], erg_dict_std_hom[key] = _summarize(pooled)

# Transpose so that rows are conditions and columns are metrics.
df_mean_hom = pd.DataFrame(erg_dict_mean_hom).T
df_std_hom = pd.DataFrame(erg_dict_std_hom).T
