In [None]:

import numpy as np
import muon as mu
import scanpy as sc
import anndata as ad
import os
import random
from models import scSpecies
import torch
import json
import pandas as pd

from create_datasets.preprocessing import set_random_seed
from create_datasets.preprocessing import create_mdata

# Reload imported project modules automatically before each cell runs, so edits
# to `models` / `create_datasets` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Fix the global random seed for the session.
# NOTE(review): presumably seeds python/numpy/torch RNGs — confirm in
# create_datasets.preprocessing.set_random_seed.
set_random_seed(1234)

# Project root (current working directory) with forward slashes and a trailing '/'.
path = os.path.abspath('').replace('\\', '/')+'/'
# Input datasets (.h5mu files) and output directory for result files.
data_path = path+'dataset/'
save_path = path+'results/'

# Pick the best available torch backend: CUDA, then Apple MPS, then CPU.
# The MPS check is guarded because older torch builds have no `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Main benchmark: for every dataset pair and hyperparameter setting, train and
# evaluate scSpecies ten times (i = 0..9) and write the resulting MuData to disk.

# Hyperparameter settings to compare.
h_dict_list = [
    {'k_neigh': 25, 'alignment': 'latent'},
    {'k_neigh': 1, 'alignment': 'latent'},
    {'k_neigh': 250, 'alignment': 'latent'},
    {'k_neigh': 25, 'alignment': 'inter'},
]

# Dataset pairs: `load_key` selects the .h5mu file, `context_key`/`target_key`
# select the modalities used as context and target dataset.
dataset_list = [
    {'dataset': "liver_human", 'context_key': 'mouse', 'target_key': 'human', 'load_key': 'liver'},
    {'dataset': "glio", 'context_key': 'mouse', 'target_key': 'human', 'load_key': 'glio'},
    {'dataset': "adipose", 'context_key': 'mouse', 'target_key': 'human', 'load_key': 'adipose'},
    {'dataset': "liver_mouseNafld", 'context_key': 'mouse', 'target_key': 'mouseNafld', 'load_key': 'liver'},
]

# Create the output directories up front so a missing folder does not abort a
# run only after training has finished.
sim_metric_dir = path + 'save_adata/'
os.makedirs(save_path, exist_ok=True)
os.makedirs(sim_metric_dir, exist_ok=True)

for dataset_dict in dataset_list:
    dataset = dataset_dict['dataset']
    context_key = dataset_dict['context_key']
    target_key = dataset_dict['target_key']
    load_key = dataset_dict['load_key']

    for params in h_dict_list:

        k_neigh = params['k_neigh']
        alignment = params['alignment']

        for i in range(10):
            save_key = '_'.join([load_key, context_key, target_key, str(k_neigh), alignment, str(i)])

            print(dataset, i, k_neigh, alignment)

            # Reload the data for every run so each model starts from the file on disk.
            mdata = mu.read_h5mu(data_path + f"{load_key}.h5mu")

            # A different seed per repeat (1234*i, i.e. 0 for the first repeat).
            model = scSpecies(device,
                              mdata,
                              path,
                              k_neigh = k_neigh,
                              alignment = alignment,
                              context_dataset_key = context_key,
                              target_dataset_key = target_key,
                              random_seed = 1234*i
                              )

            model.train_context(30, save_model=False)
            model.eval_context()

            # Prototype tracking is only enabled for the reference configuration
            # (liver mouse -> human, k_neigh=25, latent alignment); its similarity
            # metric is stored as a .npy file.
            if dataset == "liver_human" and k_neigh == 25 and alignment == 'latent':
                model.train_target(30, save_model=False, track_prototypes=True)
                model.eval_target()
                np.save(sim_metric_dir + save_key + "_sim_metric.npy", np.array(model.sim_metric))

            else:
                model.train_target(30, save_model=False)
                model.eval_target()

            model.save_params(save_key, save='both', name='')

            model.eval_label_transfer([('cell_type_coarse', 'cell_type_coarse'), ('cell_type_fine', 'cell_type_fine')])

            # Disabled alternative export of the latent representation as AnnData:
            #adata = ad.AnnData(np.concat([model.mdata[context_key].obsm['latent_mu'], model.mdata[target_key].obsm['latent_mu']]))
            #obs = pd.concat([model.mdata[context_key].obs, model.mdata[target_key].obs])
            #obs['species'] = np.array(['context'] * model.mdata[context_key].n_obs + ['target'] * model.mdata[target_key].n_obs)
            #adata.obs = obs

            #adata.uns['ind_nns_aligned_latent_space'] = model.mdata[target_key].obsm['ind_nns_aligned_latent_space']
            #adata.uns['nlog_likeli_nns_aligned_latent_space'] = model.mdata[target_key].obsm['nlog_likeli_nns_aligned_latent_space']
            #adata.write(path+'save_adata/'+save_key+'.h5ad')

            # `save_path` already ends with '/', so no extra separator is needed.
            model.mdata.write(save_path + save_key + '.h5mu')

In [None]:
# Dataset-size sweep: subsample the target dataset to a fixed number of cells,
# then train and evaluate scSpecies ten times (i = 0..9) per size.

h_dict_list = [
    {'k_neigh': 25, 'alignment': 'latent'},
]

dataset_list = [
    {'dataset': "liver_human", 'context_key': 'mouse', 'target_key': 'human', 'load_key': 'liver'},
]

# Create the output directory up front so a missing folder does not abort a
# run only after training has finished.
os.makedirs(save_path, exist_ok=True)

for dataset_size in [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 20000, 50000]:

    # Target-training settings depend only on the subsample size: smaller target
    # sets are trained for more epochs with a larger `top_percent`.
    # The boundary case (exactly 10000 cells) uses the small-dataset settings.
    if dataset_size > 10000:
        train_epochs = 30
        top_percent = 20
    else:
        train_epochs = 60
        top_percent = 50

    for dataset_dict in dataset_list:
        dataset = dataset_dict['dataset']
        context_key = dataset_dict['context_key']
        target_key = dataset_dict['target_key']
        load_key = dataset_dict['load_key']

        for params in h_dict_list:

            k_neigh = params['k_neigh']
            alignment = params['alignment']

            for i in range(10):
                save_key = '_'.join([load_key, context_key, target_key, str(k_neigh), alignment, str(i)]) + '_size_' + str(dataset_size)

                # Reload the data for every run; the subsampling below modifies it in place.
                mdata = mu.read_h5mu(data_path + f"{load_key}.h5mu")

                # Subsample the target modality in place to `dataset_size` cells.
                # NOTE(review): muon recommends calling `mdata.update()` after a
                # modality is changed in place — confirm whether scSpecies needs it.
                sc.pp.subsample(mdata.mod[target_key], n_obs=dataset_size, random_state=1234*i)

                model = scSpecies(device,
                                  mdata,
                                  path,
                                  k_neigh = k_neigh,
                                  alignment = alignment,
                                  top_percent = top_percent,
                                  context_dataset_key = context_key,
                                  target_dataset_key = target_key,
                                  random_seed = 1234*i
                                  )

                model.train_context(30, save_model=False)
                model.eval_context()

                # Train for `train_epochs` epochs (not `dataset_size`, which is a cell count).
                model.train_target(train_epochs, save_model=False)
                model.eval_target()

                model.save_params(save_key, save='both', name='')

                model.eval_label_transfer([('cell_type_coarse', 'cell_type_coarse'), ('cell_type_fine', 'cell_type_fine')])

                # Disabled alternative export of the latent representation as AnnData:
                #adata = ad.AnnData(np.concat([model.mdata[context_key].obsm['latent_mu'], model.mdata[target_key].obsm['latent_mu']]))
                #obs = pd.concat([model.mdata[context_key].obs, model.mdata[target_key].obs])
                #obs['species'] = np.array(['context'] * model.mdata[context_key].n_obs + ['target'] * model.mdata[target_key].n_obs)
                #adata.obs = obs

                #adata.uns['ind_nns_aligned_latent_space'] = model.mdata[target_key].obsm['ind_nns_aligned_latent_space']
                #adata.uns['nlog_likeli_nns_aligned_latent_space'] = model.mdata[target_key].obsm['nlog_likeli_nns_aligned_latent_space']

                #adata.write(path+'save_adata/'+save_key+'.h5ad')

                # `save_path` already ends with '/', so no extra separator is needed.
                model.mdata.write(save_path + save_key + '.h5mu')

In [None]:
# Gene-reduction sweep: remove a random subset of the genes shared between the
# context and target dataset, then train and evaluate scSpecies ten times
# (i = 0..9) per number of removed genes.

h_dict_list = [
    {'k_neigh': 25, 'alignment': 'latent'},
]

dataset_list = [
    {'dataset': "liver_human", 'context_key': 'mouse', 'target_key': 'human', 'load_key': 'liver'},
]

# Training settings, defined here so the cell runs on a fresh kernel instead of
# depending on variables left over from another cell. The values equal the
# settings the dataset-size sweep uses for target sets with more than 10000 cells.
# NOTE(review): assumes the full target dataset is in that size range — confirm.
train_epochs = 30
top_percent = 20

# Create the output directory up front so a missing folder does not abort a
# run only after training has finished.
os.makedirs(save_path, exist_ok=True)

for reduced_genes in [200, 400, 600, 800, 1000, 1200, 1400, 1600]:

    for dataset_dict in dataset_list:
        dataset = dataset_dict['dataset']
        context_key = dataset_dict['context_key']
        target_key = dataset_dict['target_key']
        load_key = dataset_dict['load_key']

        for params in h_dict_list:

            k_neigh = params['k_neigh']
            alignment = params['alignment']

            for i in range(10):
                save_key = '_'.join([load_key, context_key, target_key, str(k_neigh), alignment, str(i)]) + '_reduced_' + str(reduced_genes)

                mdata = mu.read_h5mu(data_path + f"{load_key}.h5mu")

                # Gene names of both modalities in the 'human_gene_names' column.
                v_ctx = mdata.mod[context_key].var['human_gene_names'].astype(str).to_numpy()
                v_tgt = mdata.mod[target_key].var['human_gene_names'].astype(str).to_numpy()

                shared_genes = np.intersect1d(v_ctx, v_tgt)

                # Randomly pick `reduced_genes` shared genes to remove
                # (uses numpy's global random state).
                perm = np.random.permutation(len(shared_genes))
                rem = np.asarray(shared_genes[perm[:reduced_genes]], dtype=str)

                # For every removed gene, drop the first matching variable of each
                # modality and keep the indices of all remaining variables.
                context_drop = np.array([np.where(np.isin(v_ctx, r))[0][0] for r in rem])
                target_drop = np.array([np.where(np.isin(v_tgt, r))[0][0] for r in rem])
                context_ind = np.setdiff1d(np.arange(mdata.mod[context_key].n_vars), context_drop)
                target_ind = np.setdiff1d(np.arange(mdata.mod[target_key].n_vars), target_drop)

                preprocess = create_mdata(mdata.mod[context_key][:, context_ind],
                                          context_batch_key='batch',
                                          context_cell_key='cell_type_fine',
                                          context_dataset_name=context_key,
                                          context_gene_naming_convention=target_key)

                preprocess.setup_target_adata(mdata.mod[target_key][:, target_ind],
                                              target_batch_key='batch',
                                              target_cell_key='cell_type_fine',
                                              target_dataset_name=target_key,
                                              target_gene_naming_convention=target_key)

                # NOTE(review): `preprocess` is built from the gene-reduced modalities,
                # but the model below is still constructed from the unreduced `mdata`.
                # If the reduced data is meant to be used, pass the MuData produced by
                # `preprocess` instead — confirm how create_mdata exposes it.
                model = scSpecies(device,
                                  mdata,
                                  path,
                                  k_neigh = k_neigh,
                                  alignment = alignment,
                                  top_percent = top_percent,
                                  context_dataset_key = context_key,
                                  target_dataset_key = target_key,
                                  random_seed = 1234*i
                                  )

                model.train_context(30, save_model=False)
                model.eval_context()

                model.train_target(train_epochs, save_model=False)
                model.eval_target()

                model.save_params(save_key, save='both', name='')

                model.eval_label_transfer([('cell_type_coarse', 'cell_type_coarse'), ('cell_type_fine', 'cell_type_fine')])

                # Disabled alternative export of the latent representation as AnnData:
                #adata = ad.AnnData(np.concat([model.mdata[context_key].obsm['latent_mu'], model.mdata[target_key].obsm['latent_mu']]))
                #obs = pd.concat([model.mdata[context_key].obs, model.mdata[target_key].obs])
                #obs['species'] = np.array(['context'] * model.mdata[context_key].n_obs + ['target'] * model.mdata[target_key].n_obs)
                #adata.obs = obs

                #adata.uns['ind_nns_aligned_latent_space'] = model.mdata[target_key].obsm['ind_nns_aligned_latent_space']
                #adata.uns['nlog_likeli_nns_aligned_latent_space'] = model.mdata[target_key].obsm['nlog_likeli_nns_aligned_latent_space']
                #adata.uns['context_ind'] = context_ind
                #adata.uns['target_ind'] = target_ind
                #adata.write(path+'save_adata/'+save_key+'.h5ad')

                # `save_path` already ends with '/', so no extra separator is needed.
                model.mdata.write(save_path + save_key + '.h5mu')