In [None]:
import os
import time
import json
import random
import gseapy as gp
import  re
import math
import seaborn as sns
import textwrap
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import muon as mu
import scanpy as sc
import anndata as ad
import torch
import torch.nn.functional as F
from scipy.stats import ttest_ind, entropy
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.neighbors import NearestNeighbors
from statsmodels.stats.multitest import multipletests

from create_datasets.preprocessing import create_mdata, set_random_seed
from models import scSpecies

from scipy.stats import spearmanr
from scipy.stats import pearsonr
from scipy.stats import kendalltau

matplotlib.rcParams['font.family'] = 'Helvetica'

%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

set_random_seed(1234)
np.random.seed(1234)

path = os.path.abspath('').replace('\\', '/')+'/'
data_path = path+'dataset/'
save_path = os.path.abspath('').replace('\\', '/')+'/results/'

# Data-level analysis

In [None]:
# Prefer Apple-silicon MPS (the original setting) but fall back to CUDA or CPU
# so the notebook also runs on machines without an MPS backend.
# NOTE(review): assumes scSpecies accepts any torch device string - confirm.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

dataset = "liver"
context_key = 'mouse'
target_key = 'mouseNafld'
load_key = 'liver'    

# Prefix for every file this notebook writes into save_path.
save_key = dataset+'_'+context_key+'_'+target_key+'_'

i=0
# Re-seed right before model construction so this cell is repeatable on its own.
set_random_seed(1234)

mdata = mu.read_h5mu(data_path+dataset+".h5mu")

# Hyperparameters forwarded verbatim to the scSpecies constructor.
h_dict = {'k_neigh': 25,  'alignment': 'inter',  'top_percent': 20, 'latent_dim':10}

scSpecies_model = scSpecies(device, 
                mdata, 
                path,
                context_dataset_key = context_key, 
                target_dataset_key = target_key,       
                random_seed = 1234,
                use_lib_enc = True,
                **h_dict
                )

# Train and evaluate on the context dataset first, then on the target dataset.
scSpecies_model.train_context(40, save_model=False)
scSpecies_model.eval_context()

scSpecies_model.train_target(40, track_prototypes=False, save_model=False)
scSpecies_model.eval_target()

# The log-fold-change cells below read 'latent_mu', 'latent_sig', 'l_mu' and
# 'l_sig' from .obsm; presumably these calls populate them - TODO confirm.
scSpecies_model.get_representation('context', save_libsize=True) 
scSpecies_model.get_representation('target', save_libsize=True) 

## 1) Log-fold change analysis

In [None]:
def compute_logfold_change(model, context_cell_key, target_cell_key, eps=1e-6, lfc_delta=1, samples=50000, b_s=128, confidence_level=0.9, n_neighbors=25):
    """Estimate per-gene log2 fold changes between target and context decoders.

    For every cell type present under both ``context_cell_key`` and
    ``target_cell_key``, latent codes are sampled around the stored target
    posterior (``latent_mu`` / ``latent_sig``), decoded with the context and the
    target decoder for the homologous genes, and the log2 ratio of the decoded
    values is summarised per gene.

    Note that this notebook defines a second function with the same name further
    down; once that cell has run, it replaces this definition.

    Parameters
    ----------
    model : trained scSpecies model; must expose ``mdata``, the encoder/decoder
        modules, ``filter_outliers`` and ``average_slices``.
    context_cell_key, target_cell_key : ``.obs`` columns holding the cell-type
        labels of the context and target dataset.
    eps : pseudo-count added before taking ``log2``.
    lfc_delta : threshold; ``p`` is the fraction of samples with ``|LFC| > lfc_delta``.
        The value is also stored in ``uns['lfc_delta']`` of the context dataset.
    samples : approximate number of sampled cells per cell type; the target cells
        are re-sampled ``ceil(samples / n_target_cells)`` times.
    b_s : number of target cells decoded per step.
    confidence_level : forwarded to ``model.filter_outliers``.
    n_neighbors : number of context neighbours whose library sizes are averaged
        (capped at the number of available context cells).

    Returns
    -------
    dict mapping cell type -> DataFrame indexed by homologous target gene names
    with columns ``rho_median_context``, ``mu_median_context``,
    ``rho_median_target``, ``mu_median_target``, ``lfc``, ``p``, ``lfc_rand``
    and ``p_rand``. The ``*_rand`` columns compare against context genes in a
    random permutation and serve as a null reference. Cell types that have no
    cells left after outlier filtering are omitted.

    Raises
    ------
    ValueError if ``target_cell_key`` is None.
    """
    if target_cell_key is None:
        # Cells are paired through cell types shared by both datasets, so target
        # labels are mandatory; fail with a clear message instead of a KeyError.
        raise ValueError("target_cell_key must name a column in the target dataset's .obs")

    context_adata = model.mdata.mod[model.context_dataset_key]
    target_adata = model.mdata.mod[model.target_dataset_key]
    context_adata.uns['lfc_delta'] = lfc_delta

    for module in (model.context_decoder, model.target_decoder,
                   model.context_encoder_inner, model.target_encoder_inner,
                   model.context_encoder_outer, model.target_encoder_outer,
                   model.context_lib_encoder, model.target_lib_encoder):
        module.eval()

    target_ind = np.array(model.target_param_dict['homologous_genes'])
    target_gene_names = target_adata.var_names.to_numpy()[target_ind]

    def _cells_by_type(adata, cell_key):
        # Unique cell types and, per type, the positional indices of its cells.
        labels = adata.obs[cell_key].to_numpy()
        cell_types = np.unique(labels)
        return cell_types, {c: np.where(labels == c)[0] for c in cell_types}

    def _encoded_batches(adata, cell_key, cell_types):
        # Per cell type: one-hot rows of every batch holding more than 3 of its cells.
        batch_key = adata.uns['dataset_batch_key']
        batch_labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
        encoder = OneHotEncoder()
        encoder.fit(batch_labels)

        def _one_hot(values):
            return encoder.transform(values.reshape(-1, 1)).toarray().astype(np.float32)

        encoded = {}
        for c in cell_types:
            counts = adata.obs.loc[(adata.obs[cell_key] == c).to_numpy(), batch_key].value_counts()
            encoded[c] = _one_hot(counts[counts > 3].index.to_numpy())
        encoded['unknown'] = _one_hot(np.unique(batch_labels))
        return encoded

    def _decode_batch_averaged(decoder, z, batch_onehot):
        # Decode every cell once per batch label, then let the model collapse the
        # per-batch rows of each cell (presumably a mean - see model.average_slices).
        n_cells, n_batches = z.shape[0], batch_onehot.shape[0]
        z_rep = torch.from_numpy(np.repeat(z, n_batches, axis=0)).to(model.device)
        labels = torch.from_numpy(np.tile(batch_onehot, (n_cells, 1))).to(model.device)
        rho = decoder.decode_homologous(z_rep, labels).cpu().numpy()
        return model.average_slices(rho, np.full(n_cells, n_batches))

    context_cell_types, context_cell_index = _cells_by_type(context_adata, context_cell_key)
    target_cell_types, target_cell_index = _cells_by_type(target_adata, target_cell_key)
    joint_cell_types = np.intersect1d(context_cell_types, target_cell_types)

    context_batches = _encoded_batches(context_adata, context_cell_key, context_cell_types)
    target_batches = _encoded_batches(target_adata, target_cell_key, target_cell_types)

    lfc_dict = {}
    random_perm = np.random.permutation(len(target_gene_names))

    for cell_type in joint_cell_types:
        adata_context = context_adata[context_cell_index[cell_type]]
        adata_target = target_adata[target_cell_index[cell_type]]

        # Filter and match cells in latent space. The one-column library-size
        # embedding 'l_mu' used before makes cosine distances degenerate; the
        # later definition of this function in the notebook also uses 'latent_mu'.
        keep, _ = model.filter_outliers(adata_context.obsm['latent_mu'], confidence_level=confidence_level)
        adata_context = adata_context[keep]
        keep, _ = model.filter_outliers(adata_target.obsm['latent_mu'], confidence_level=confidence_level)
        adata_target = adata_target[keep]

        n_target = adata_target.n_obs
        if n_target == 0 or adata_context.n_obs == 0:
            # Nothing left to compare for this cell type.
            continue

        ctx_l_mu = np.asarray(adata_context.obsm['l_mu'])
        ctx_l_sig = np.asarray(adata_context.obsm['l_sig'])
        tgt_z_mu = np.asarray(adata_target.obsm['latent_mu'])
        tgt_z_sig = np.asarray(adata_target.obsm['latent_sig'])
        tgt_l_mu = np.asarray(adata_target.obsm['l_mu'])
        tgt_l_sig = np.asarray(adata_target.obsm['l_sig'])

        # Rare cell types can have fewer context cells than requested neighbours.
        k = min(n_neighbors, adata_context.n_obs)
        knn = NearestNeighbors(n_neighbors=k, metric='cosine', algorithm='auto')
        knn.fit(np.asarray(adata_context.obsm['latent_mu']))
        _, neighbor_idx = knn.kneighbors(tgt_z_mu)

        # All cells in this loop share `cell_type`, so the batch labels to decode
        # with are the same for every cell.
        context_onehot = context_batches[cell_type]
        target_onehot = target_batches[cell_type]

        iterations = int(np.ceil(samples / n_target))

        rho_context, mu_context = [], []
        rho_target, mu_target = [], []
        lfc_samples, lfc_samples_random = [], []

        with torch.no_grad():
            for _ in range(iterations):
                for start in range(0, n_target, b_s):
                    stop = start + b_s
                    z_mu = tgt_z_mu[start:stop]
                    n_cells = z_mu.shape[0]

                    # Zero-mean Gaussian noise (randn); uniform noise from
                    # np.random.rand would shift every sample upward.
                    z = np.float32(z_mu + tgt_z_sig[start:stop] * np.random.randn(*z_mu.shape))
                    target_l = np.exp(np.float32(tgt_l_mu[start:stop] + tgt_l_sig[start:stop] * np.random.randn(n_cells, 1)))

                    # Context library size: mean over the k matched context cells.
                    neigh = neighbor_idx[start:stop]
                    context_l = np.exp(np.float32(ctx_l_mu[neigh] + ctx_l_sig[neigh] * np.random.randn(n_cells, k, 1)))
                    context_l = context_l.mean(axis=1)

                    context_rho = _decode_batch_averaged(model.context_decoder, z, context_onehot)
                    target_rho = _decode_batch_averaged(model.target_decoder, z, target_onehot)

                    rho_context.append(context_rho)
                    mu_context.append(context_rho * context_l)
                    rho_target.append(target_rho)
                    mu_target.append(target_rho * target_l)

                    log_target = np.log2(target_rho + eps)
                    lfc_samples.append(log_target - np.log2(context_rho + eps))
                    lfc_samples_random.append(log_target - np.log2(context_rho[:, random_perm] + eps))

        lfc_samples = np.concatenate(lfc_samples)
        lfc_samples_random = np.concatenate(lfc_samples_random)

        lfc_dict[cell_type] = pd.DataFrame({
            'rho_median_context': np.median(np.concatenate(rho_context), axis=0),
            'mu_median_context': np.median(np.concatenate(mu_context), axis=0),
            'rho_median_target': np.median(np.concatenate(rho_target), axis=0),
            'mu_median_target': np.median(np.concatenate(mu_target), axis=0),
            'lfc': np.median(lfc_samples, axis=0),
            'p': np.mean(np.abs(lfc_samples) > lfc_delta, axis=0),
            'lfc_rand': np.median(lfc_samples_random, axis=0),
            'p_rand': np.mean(np.abs(lfc_samples_random) > lfc_delta, axis=0),
        }, index=target_gene_names)

    return lfc_dict
        
# Per-cell-type LFC tables, matching cells on the fine-grained annotation in both datasets.
lfc_dict = compute_logfold_change(
    scSpecies_model,
    'cell_type_fine',
    'cell_type_fine',
    eps=1e-6,
    lfc_delta=0.4,
    samples=50000,
    b_s=128,
    confidence_level=0.95,
)


In [None]:
# Thresholds for calling a gene differentially expressed in the plots below.
# NOTE(review): the 'p' column in lfc_dict was computed with the lfc_delta passed
# to compute_logfold_change (0.4 in the call above), not with this value - confirm
# that the mismatch is intended.
lfc_delta = 1
prob_delta = 0.9

cell_types = list(lfc_dict.keys())
df_lfc = pd.DataFrame({ct: lfc_dict[ct]['lfc'] for ct in cell_types})
df_prob = pd.DataFrame({ct: lfc_dict[ct]['p'] for ct in cell_types})

# Same statistics against randomly permuted context genes (null reference).
df_lfc_random = pd.DataFrame({ct: lfc_dict[ct]['lfc_rand'] for ct in cell_types})
df_prob_random = pd.DataFrame({ct: lfc_dict[ct]['p_rand'] for ct in cell_types})

n_cell_types = len(cell_types)
n_cols = 4  
n_rows = int(np.ceil(n_cell_types / n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows), squeeze=False)
axs = axs.flatten()

# Per-cell-type percentages of genes passing the thresholds.
ge_one = []
ge_prob = []
up_reg = []
down_reg = []

ge_one_random = []
ge_prob_random = []
up_reg_random = []
down_reg_random = []

for i, cell in enumerate(cell_types):

    lfc = df_lfc[cell]
    confident = df_prob[cell].abs() > prob_delta
    is_up = lfc > lfc_delta
    is_down = lfc < -lfc_delta

    greater_than_one = np.round((lfc.abs() > lfc_delta).mean()*100, 1)
    greater_than_one_prob = np.round(((lfc.abs() > lfc_delta) & confident).mean()*100, 1)
    up = np.round((is_up & confident).mean()*100, 1)
    down = np.round((is_down & confident).mean()*100, 1)

    ge_one.append(greater_than_one)
    ge_prob.append(greater_than_one_prob)
    up_reg.append(up)
    down_reg.append(down)

    lfc_rand = df_lfc_random[cell]
    confident_rand = df_prob_random[cell].abs() > prob_delta
    ge_one_random.append(np.round((lfc_rand.abs() > lfc_delta).mean()*100, 1))
    ge_prob_random.append(np.round(((lfc_rand.abs() > lfc_delta) & confident_rand).mean()*100, 1))
    up_reg_random.append(np.round(((lfc_rand > lfc_delta) & confident_rand).mean()*100, 1))
    down_reg_random.append(np.round(((lfc_rand < -lfc_delta) & confident_rand).mean()*100, 1))

    ax = axs[i]

    # One colour per gene: grey inside the LFC band (also for NaN values, so the
    # colour list always matches the data length), saturated when confident.
    colors = np.full(len(lfc), 'grey', dtype=object)
    colors[(is_up & confident).to_numpy()] = 'red'
    colors[(is_up & ~confident).to_numpy()] = 'lightcoral'
    colors[(is_down & confident).to_numpy()] = 'blue'
    colors[(is_down & ~confident).to_numpy()] = 'lightblue'

    ax.scatter(lfc, df_prob[cell], c=colors.tolist(), s=12, edgecolor='k')
    ax.set_xlabel('Log Fold Change')
    ax.set_ylabel('Probability')

    ax.set_title(f"{cell} |LFC|>{lfc_delta}, with p>{prob_delta}: {greater_than_one_prob}%", pad=10)

    # Guide lines follow the configured thresholds instead of hard-coded values.
    ax.axhline(prob_delta, color='black', linestyle='--', linewidth=1.2)
    ax.axvline(-lfc_delta, color='black', linestyle='--', linewidth=1.2)
    ax.axvline(lfc_delta, color='black', linestyle='--', linewidth=1.2)

    # Five strongest confident hits in each direction.
    top_up = lfc[is_up & confident].sort_values(ascending=False).head(5)
    top_down = lfc[is_down & confident].sort_values(ascending=True).head(5)

    up_text = "Upregulated:\n" + "\n".join([f"{j+1}. {gene}" for j, gene in enumerate(top_up.index)])
    down_text = "Downregulated:\n" + "\n".join([f"{j+1}. {gene}" for j, gene in enumerate(top_down.index)])

    ax.text(0.025, 0.54, up_text, transform=ax.transAxes, verticalalignment='top',
            fontsize=10, bbox=dict(boxstyle="round", alpha=0.3, facecolor="white"))
    ax.text(0.98, 0.54, down_text, transform=ax.transAxes, verticalalignment='top',
            horizontalalignment='right', fontsize=10,
            bbox=dict(boxstyle="round", alpha=0.3, facecolor="white"))

# Remove the unused panels of the grid (independent of the loop variable).
for unused_ax in axs[n_cell_types:]:
    fig.delaxes(unused_ax)

# Adjust layout to leave space for the suptitle
fig.tight_layout(rect=[0, 0, 1, 0.95])
fig.suptitle(f'Differential gene expression analysis of a {context_key} context dataset aligned with a {target_key} target dataset.\n' 
             f'Median |LFC|>{lfc_delta}: {np.round(np.mean(ge_one),1)}%, with p>{prob_delta}: {np.round(np.mean(ge_prob),1)}%. Up regulated: {np.round(np.mean(up_reg),1)}%, down regulated: {np.round(np.mean(down_reg),1)}%.\n'
             f'Permuted data, median |LFC|>{lfc_delta}: {np.round(np.mean(ge_one_random),1)}%, with p>{prob_delta}: {np.round(np.mean(ge_prob_random),1)}%. Up regulated: {np.round(np.mean(up_reg_random),1)}%, down regulated: {np.round(np.mean(down_reg_random),1)}%.',
             y=0.99, fontsize=16)

fig.savefig(save_path+save_key+"DGE.pdf", bbox_inches='tight')

plt.show()

## 2) Correlation of LFC values

In [None]:
def compute_logfold_change(model, context_cell_key, target_cell_key, eps=1e-6, lfc_delta=1, samples=50000, b_s=128, confidence_level=0.9, n_neighbors=25):
    """Log2 fold changes on normalised (rho) and library-scaled (mu) decoder outputs.

    This definition replaces the function of the same name defined earlier in the
    notebook and returns wide DataFrames (genes x cell types) instead of a dict.

    For every cell type present under both ``context_cell_key`` and
    ``target_cell_key``, latent codes are sampled around the stored target
    posterior, decoded with both decoders for the homologous genes, and the
    per-gene log2 ratios are summarised. ``mu`` values multiply the decoded
    ``rho`` by a sampled library size; the context library size is the mean over
    the nearest context cells in latent space.

    Parameters
    ----------
    model : trained scSpecies model; must expose ``mdata``, the encoder/decoder
        modules, ``filter_outliers`` and ``average_slices``.
    context_cell_key, target_cell_key : ``.obs`` columns with cell-type labels.
    eps : pseudo-count added before taking ``log2``.
    lfc_delta : threshold for the reported probabilities ``P(|LFC| > lfc_delta)``;
        also stored in ``uns['lfc_delta']`` of the context dataset.
    samples : approximate number of sampled cells per cell type; the target cells
        are re-sampled ``ceil(samples / n_target_cells)`` times.
    b_s : number of target cells decoded per step.
    confidence_level : forwarded to ``model.filter_outliers``.
    n_neighbors : number of context neighbours whose library sizes are averaged
        (capped at the number of available context cells).

    Returns
    -------
    Tuple of nine DataFrames indexed by homologous target gene names with one
    column per shared cell type, in this order: ``df_lfc_rho``, ``df_prob_rho``,
    ``df_lfc_mu``, ``df_prob_mu``, ``df_lfc_rho_random``, ``df_prob_rho_random``,
    ``df_lfc_mu_random``, ``df_prob_mu_random``, ``df_diff``. The ``*_random``
    frames compare against context genes in a random permutation; ``df_diff``
    repeats, for every gene, the mean of ``log2(context_l) - log2(target_l)``.
    Columns of cell types without cells after outlier filtering stay NaN.

    Raises
    ------
    ValueError if ``target_cell_key`` is None.
    """
    if target_cell_key is None:
        # Cells are paired through cell types shared by both datasets, so target
        # labels are mandatory; fail with a clear message instead of a KeyError.
        raise ValueError("target_cell_key must name a column in the target dataset's .obs")

    context_adata = model.mdata.mod[model.context_dataset_key]
    target_adata = model.mdata.mod[model.target_dataset_key]
    context_adata.uns['lfc_delta'] = lfc_delta

    for module in (model.context_decoder, model.target_decoder,
                   model.context_encoder_inner, model.target_encoder_inner,
                   model.context_encoder_outer, model.target_encoder_outer,
                   model.context_lib_encoder, model.target_lib_encoder):
        module.eval()

    target_ind = np.array(model.target_param_dict['homologous_genes'])
    target_gene_names = target_adata.var_names.to_numpy()[target_ind]

    def _cells_by_type(adata, cell_key):
        # Unique cell types and, per type, the positional indices of its cells.
        labels = adata.obs[cell_key].to_numpy()
        cell_types = np.unique(labels)
        return cell_types, {c: np.where(labels == c)[0] for c in cell_types}

    def _encoded_batches(adata, cell_key, cell_types):
        # Per cell type: one-hot rows of every batch holding more than 3 of its cells.
        batch_key = adata.uns['dataset_batch_key']
        batch_labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
        encoder = OneHotEncoder()
        encoder.fit(batch_labels)

        def _one_hot(values):
            return encoder.transform(values.reshape(-1, 1)).toarray().astype(np.float32)

        encoded = {}
        for c in cell_types:
            counts = adata.obs.loc[(adata.obs[cell_key] == c).to_numpy(), batch_key].value_counts()
            encoded[c] = _one_hot(counts[counts > 3].index.to_numpy())
        encoded['unknown'] = _one_hot(np.unique(batch_labels))
        return encoded

    def _decode_batch_averaged(decoder, z, batch_onehot):
        # Decode every cell once per batch label, then let the model collapse the
        # per-batch rows of each cell (presumably a mean - see model.average_slices).
        n_cells, n_batches = z.shape[0], batch_onehot.shape[0]
        z_rep = torch.from_numpy(np.repeat(z, n_batches, axis=0)).to(model.device)
        labels = torch.from_numpy(np.tile(batch_onehot, (n_cells, 1))).to(model.device)
        rho = decoder.decode_homologous(z_rep, labels).cpu().numpy()
        return model.average_slices(rho, np.full(n_cells, n_batches))

    def _summarise(lfc_samples):
        # Per-gene median LFC and fraction of samples exceeding the threshold.
        return np.median(lfc_samples, axis=0), np.mean(np.abs(lfc_samples) > lfc_delta, axis=0)

    context_cell_types, context_cell_index = _cells_by_type(context_adata, context_cell_key)
    target_cell_types, target_cell_index = _cells_by_type(target_adata, target_cell_key)
    joint_cell_types = np.intersect1d(context_cell_types, target_cell_types)

    context_batches = _encoded_batches(context_adata, context_cell_key, context_cell_types)
    target_batches = _encoded_batches(target_adata, target_cell_key, target_cell_types)

    def _empty_frame():
        # NaN-initialised so skipped cell types are visible rather than reported as 0.
        return pd.DataFrame(np.nan, index=target_gene_names, columns=joint_cell_types)

    df_lfc_rho, df_prob_rho = _empty_frame(), _empty_frame()
    df_lfc_mu, df_prob_mu = _empty_frame(), _empty_frame()
    df_lfc_rho_random, df_prob_rho_random = _empty_frame(), _empty_frame()
    df_lfc_mu_random, df_prob_mu_random = _empty_frame(), _empty_frame()
    df_diff = _empty_frame()

    random_perm = np.random.permutation(len(target_gene_names))

    for cell_type in joint_cell_types:
        adata_context = context_adata[context_cell_index[cell_type]]
        adata_target = target_adata[target_cell_index[cell_type]]

        keep, _ = model.filter_outliers(adata_context.obsm['latent_mu'], confidence_level=confidence_level)
        adata_context = adata_context[keep]
        keep, _ = model.filter_outliers(adata_target.obsm['latent_mu'], confidence_level=confidence_level)
        adata_target = adata_target[keep]

        n_target = adata_target.n_obs
        if n_target == 0 or adata_context.n_obs == 0:
            # Nothing left to compare for this cell type; its columns stay NaN.
            continue

        ctx_l_mu = np.asarray(adata_context.obsm['l_mu'])
        ctx_l_sig = np.asarray(adata_context.obsm['l_sig'])
        tgt_z_mu = np.asarray(adata_target.obsm['latent_mu'])
        tgt_z_sig = np.asarray(adata_target.obsm['latent_sig'])
        tgt_l_mu = np.asarray(adata_target.obsm['l_mu'])
        tgt_l_sig = np.asarray(adata_target.obsm['l_sig'])

        # Rare cell types can have fewer context cells than requested neighbours.
        k = min(n_neighbors, adata_context.n_obs)
        knn = NearestNeighbors(n_neighbors=k, metric='cosine', algorithm='auto')
        knn.fit(np.asarray(adata_context.obsm['latent_mu']))
        _, neighbor_idx = knn.kneighbors(tgt_z_mu)

        # All cells in this loop share `cell_type`, so the batch labels to decode
        # with are the same for every cell.
        context_onehot = context_batches[cell_type]
        target_onehot = target_batches[cell_type]

        iterations = int(np.ceil(samples / n_target))

        lfc_rho, lfc_mu = [], []
        lfc_rho_random, lfc_mu_random = [], []
        lib_diff = []

        with torch.no_grad():
            for _ in range(iterations):
                for start in range(0, n_target, b_s):
                    stop = start + b_s
                    z_mu = tgt_z_mu[start:stop]
                    n_cells = z_mu.shape[0]

                    # Zero-mean Gaussian noise (randn); uniform noise from
                    # np.random.rand would shift every sample upward.
                    z = np.float32(z_mu + tgt_z_sig[start:stop] * np.random.randn(*z_mu.shape))
                    target_l = np.exp(np.float32(tgt_l_mu[start:stop] + tgt_l_sig[start:stop] * np.random.randn(n_cells, 1)))

                    # Context library size: mean over the k matched context cells.
                    neigh = neighbor_idx[start:stop]
                    context_l = np.exp(np.float32(ctx_l_mu[neigh] + ctx_l_sig[neigh] * np.random.randn(n_cells, k, 1)))
                    context_l = context_l.mean(axis=1)

                    context_rho = _decode_batch_averaged(model.context_decoder, z, context_onehot)
                    target_rho = _decode_batch_averaged(model.target_decoder, z, target_onehot)

                    context_mu = context_rho * context_l
                    target_mu = target_rho * target_l

                    log_target_rho = np.log2(target_rho + eps)
                    log_target_mu = np.log2(target_mu + eps)

                    lfc_rho.append(log_target_rho - np.log2(context_rho + eps))
                    lfc_mu.append(log_target_mu - np.log2(context_mu + eps))

                    lfc_rho_random.append(log_target_rho - np.log2(context_rho[:, random_perm] + eps))
                    lfc_mu_random.append(log_target_mu - np.log2(context_mu[:, random_perm] + eps))

                    lib_diff.append(np.log2(context_l) - np.log2(target_l))

        df_lfc_rho[cell_type], df_prob_rho[cell_type] = _summarise(np.concatenate(lfc_rho))
        df_lfc_mu[cell_type], df_prob_mu[cell_type] = _summarise(np.concatenate(lfc_mu))
        df_lfc_rho_random[cell_type], df_prob_rho_random[cell_type] = _summarise(np.concatenate(lfc_rho_random))
        df_lfc_mu_random[cell_type], df_prob_mu_random[cell_type] = _summarise(np.concatenate(lfc_mu_random))

        # Scalar library-size offset, broadcast to every gene of this cell type.
        df_diff[cell_type] = np.mean(np.concatenate(lib_diff))

    return df_lfc_rho, df_prob_rho, df_lfc_mu, df_prob_mu, df_lfc_rho_random, df_prob_rho_random, df_lfc_mu_random, df_prob_mu_random, df_diff
        


In [None]:
# Compare model-derived LFC values with LFC values computed directly from
# normalised counts. Only runs for the cross-species (human target) setting.
if target_key == 'human':

    (df_lfc_rho, df_prob_rho, df_lfc_mu, df_prob_mu,
     df_lfc_rho_random, df_prob_rho_random, df_lfc_mu_random, df_prob_mu_random,
     df_diff) = compute_logfold_change(
        scSpecies_model, 'cell_type_fine', 'cell_type_fine',
        eps=1e-6, lfc_delta=0.4, samples=1, b_s=128, confidence_level=0.95)

    target_ind = np.array(scSpecies_model.target_param_dict['homologous_genes'])
    context_ind = np.array(scSpecies_model.context_param_dict['homologous_genes'])
    target_gene_names = scSpecies_model.mdata.mod[scSpecies_model.target_dataset_key].var_names.to_numpy()[target_ind]
    joint_cell_types = list(lfc_dict.keys())

    df_lfc_dat = pd.DataFrame(0, index=target_gene_names, columns=joint_cell_types)

    spear = {}
    pear = {}
    kend = {}

    # Restricting to homologous genes does not depend on the cell type, so do it once.
    # NOTE(review): assumes both 'homologous_genes' index lists are aligned gene by gene - confirm.
    target_homologous = scSpecies_model.mdata.mod[scSpecies_model.target_dataset_key][:, target_ind]
    context_homologous = scSpecies_model.mdata.mod[scSpecies_model.context_dataset_key][:, context_ind]

    for cell_type in joint_cell_types:
        # Explicit copies so normalisation never touches the shared MuData object.
        adata_target = target_homologous[target_homologous.obs['cell_type_fine'] == cell_type].copy()
        adata_context = context_homologous[context_homologous.obs['cell_type_fine'] == cell_type].copy()

        # Counts per million, then the mean expression profile of the cell type.
        sc.pp.normalize_total(adata_context, target_sum=1e6, inplace=True)
        sc.pp.normalize_total(adata_target, target_sum=1e6, inplace=True)

        mean_context = np.mean(adata_context.X.toarray(), axis=0)
        mean_target = np.mean(adata_target.X.toarray(), axis=0)

        lfc = np.log2(mean_target+1) - np.log2(mean_context+1) 
        df_lfc_dat[cell_type] = lfc

        model_lfc = lfc_dict[cell_type]['lfc'].to_numpy()

        spear[cell_type] = spearmanr(lfc, model_lfc).statistic
        pearson_result = pearsonr(lfc, model_lfc)
        pear[cell_type] = pearson_result.statistic
        # Kendall's tau is rank based, so correlate the two LFC vectors directly;
        # going through argsort permutations breaks ties (e.g. genes with LFC 0
        # in the data) in arbitrary order.
        kend[cell_type] = kendalltau(lfc, model_lfc).statistic   

    lfc_delta = 1
    prob_delta = 0.9

    # NOTE(review): the correlations above use lfc_dict (earlier run) while the
    # scatter below plots df_lfc_rho from the samples=1 run - confirm this is intended.
    cell_types = [ct for ct in df_prob_rho.columns if ct in spear]

    n_cell_types = len(cell_types)
    n_cols = 4  
    n_rows = int(np.ceil(n_cell_types / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows), squeeze=False)
    axs = axs.flatten()


    for i, cell_type in enumerate(cell_types):
        ax = axs[i]

        # Both frames share the same gene index, so plot the columns positionally.
        ax.scatter(df_lfc_dat[cell_type].to_numpy(), df_lfc_rho[cell_type].to_numpy(), c='darkgrey', s=12, edgecolor='k')
        ax.set_xlabel('LFC data-level')
        ax.set_ylabel('LFC scSpecies')
        ax.set_title(f"{cell_type}\n ρ: {str(round(spear[cell_type],2))}, PCC: {str(np.round(pear[cell_type],2))}, Kendall's τ: {str(round(kend[cell_type],2))}", pad=10, fontsize=14)

        # Identity line spanning the joint range of both axes.
        lims = [
            np.min([ax.get_xlim(), ax.get_ylim()]),
            np.max([ax.get_xlim(), ax.get_ylim()])
        ]

        ax.plot(lims, lims, '--', color='red', linewidth=1) 

    # Remove the unused panels of the grid (independent of the loop variable).
    for unused_ax in axs[n_cell_types:]:
        fig.delaxes(unused_ax)

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.suptitle(f"LFC values derived by scSpecies vs. LFC values by a data-level analysis. \n" 
                f"Average Spearman's ρ: {str(np.round(np.mean([spear[ct] for ct in spear.keys()]), 2))}, Pearson correlation: {str(np.round(np.mean([pear[ct] for ct in spear.keys()]), 2))}. Kendall's τ: {str(np.round(np.mean([kend[ct] for ct in spear.keys()]), 2))}.", fontsize=18)

    fig.savefig(save_path+save_key+"Comparison_data_vs_model.pdf", bbox_inches='tight')
    plt.show()

## 3) Data-level reconstruction error

In [None]:
def eval_model(model, eval_key, b_s=128):
    """Run one dataset through a trained model on the CPU and collect its outputs.

    The encoder/decoder networks that belong to ``eval_key`` are switched to
    eval mode and moved to the CPU (they are left there afterwards). The
    dataset is then processed in consecutive mini-batches of ``b_s`` cells.

    Parameters
    ----------
    model
        Trained model exposing ``mdata`` plus the ``context_*`` / ``target_*``
        encoder, library-encoder and decoder networks.
    eval_key : {'context', 'target'}
        Which of the model's two network stacks / datasets to evaluate.
    b_s : int
        Mini-batch size.

    Returns
    -------
    tuple
        ``(interm, likeli_list, rec_list, norm_list, l_list, l_mu_list,
        l_log_sig_list, data_counts)``:

        * ``interm`` - dict; integer keys hold the per-batch activations after
          each layer of ``encoder_outer.model`` followed by
          ``encoder_inner.model``, ``'mu'`` holds the per-batch latent means and
          ``'data'`` an AnnData wrapping the raw count matrix.
        * ``likeli_list`` - per-cell log-likelihood (summed over genes) under a
          zero-inflated negative binomial.
        * ``rec_list`` - one sampled reconstruction per cell.
        * ``norm_list`` - decoder output ``rho`` (clamped to ``[1e-7, 1-1e-7]``).
        * ``l_list`` / ``l_mu_list`` / ``l_log_sig_list`` - sampled library size
          and the library encoder's two outputs.
        * ``data_counts`` - dense count matrix as a tensor.

    Raises
    ------
    ValueError
        If ``eval_key`` is neither ``'context'`` nor ``'target'``.
    """
    if eval_key == 'target':
        dataset_key = model.target_dataset_key
        decoder = model.target_decoder
        encoder_inner = model.target_encoder_inner
        lib_encoder = model.target_lib_encoder
        encoder_outer = model.target_encoder_outer
    elif eval_key == 'context':
        dataset_key = model.context_dataset_key
        decoder = model.context_decoder
        # Fix: the context stack must use its own inner encoder (the original
        # picked `target_encoder_inner` here, a copy-paste slip).
        encoder_inner = model.context_encoder_inner
        lib_encoder = model.context_lib_encoder
        encoder_outer = model.context_encoder_outer
    else:
        raise ValueError(f"eval_key must be 'context' or 'target', got {eval_key!r}")

    adata = model.mdata.mod[dataset_key]
    n_obs = adata.n_obs
    steps_per_epoch = int(np.ceil(n_obs / b_s))

    # assumes X is a scipy sparse matrix (``toarray``) - TODO confirm for new datasets
    data_counts = torch.from_numpy(adata.X.toarray())
    # Read the batch encoding once instead of building an AnnData view per mini-batch.
    batch_labels = torch.from_numpy(adata.obsm['batch_label_enc'])

    for module in (encoder_outer, lib_encoder, decoder, encoder_inner):
        module.eval()
        module.cpu()

    n_outer = len(encoder_outer.model)
    interm = {i: [] for i in range(n_outer + len(encoder_inner.model))}
    interm['mu'] = []

    likeli_list = []
    rec_list = []
    norm_list = []
    l_list = []
    l_mu_list = []
    l_log_sig_list = []

    eps = 1e-7

    with torch.no_grad():
        for step in range(steps_per_epoch):
            start_idx = step * b_s
            end_idx = min((step + 1) * b_s, n_obs)

            data_batch = data_counts[start_idx:end_idx]
            label_batch = batch_labels[start_idx:end_idx]

            # Layer-by-layer pass that records every intermediate activation.
            x = torch.cat((data_batch, label_batch), dim=-1)
            for i, layer in enumerate(encoder_outer.model):
                x = layer(x)
                interm[i].append(x)

            for j, layer in enumerate(encoder_inner.model):
                x = layer(x)
                interm[n_outer + j].append(x)

            interm['mu'].append(encoder_inner.mu(x))

            # Regular forward pass used for sampling and likelihood evaluation.
            inter = encoder_outer(data_batch, label_batch)
            z_mu, z_log_sig = encoder_inner.encode(inter)
            # Fix: reparameterisation needs standard-normal noise (randn), not
            # uniform noise on [0, 1) (rand).
            z = z_mu + z_log_sig.exp() * torch.randn_like(z_log_sig)
            l_mu, l_log_sig = lib_encoder.encode(data_batch, label_batch)
            lib_size = torch.exp(l_mu + l_log_sig.exp() * torch.randn_like(l_log_sig))

            alpha, rho, pi_nlogit = decoder.decode(z, label_batch)
            alpha = torch.clamp(alpha, min=eps)
            rho = torch.clamp(rho, min=eps, max=1 - eps)
            pi = torch.sigmoid(-pi_nlogit)
            nb_mean = rho * lib_size

            # torch parameterises the NB by (total_count, probs); this choice
            # gives a distribution whose mean equals `nb_mean`.
            probs = torch.clamp(nb_mean / (nb_mean + alpha), min=eps, max=1 - eps)

            zero_mask = torch.bernoulli(pi.expand(1, *pi.shape))
            nb_dist = torch.distributions.NegativeBinomial(total_count=alpha, probs=probs)
            samples = nb_dist.sample((1,))
            # Fix: drop only the leading sample dimension. A bare `.squeeze()`
            # also removes the batch dimension when the last mini-batch holds
            # a single cell, which breaks the concatenation below.
            samples = torch.where(zero_mask.bool(), torch.zeros_like(samples), samples).squeeze(0)

            log_alpha_mu = torch.log(alpha + nb_mean)
            log_likelihood = torch.where(
                data_batch < eps,
                # zero counts: dropout or an NB zero
                F.softplus(pi_nlogit + alpha * (torch.log(alpha) - log_alpha_mu)) - F.softplus(pi_nlogit),
                # positive counts: no dropout and the NB mass at the count
                - F.softplus(pi_nlogit) + pi_nlogit
                + alpha * (torch.log(alpha) - log_alpha_mu) + data_batch * (torch.log(nb_mean) - log_alpha_mu)
                + torch.lgamma(data_batch + alpha) - torch.lgamma(alpha) - torch.lgamma(1.0 + data_batch))

            likeli_list.append(log_likelihood.sum(-1))
            rec_list.append(samples)
            norm_list.append(rho)
            l_list.append(lib_size)
            l_mu_list.append(l_mu)
            l_log_sig_list.append(l_log_sig)

        likeli_list = torch.concatenate(likeli_list)
        rec_list = torch.concatenate(rec_list)
        norm_list = torch.concatenate(norm_list)
        l_list = torch.concatenate(l_list)
        l_mu_list = torch.concatenate(l_mu_list)
        l_log_sig_list = torch.concatenate(l_log_sig_list)

    interm['data'] = sc.AnnData(adata.X)

    return interm, likeli_list, rec_list, norm_list, l_list, l_mu_list, l_log_sig_list, data_counts

# Baseline model: the *target* dataset is passed as this model's context dataset
# (and vice versa) with `train_only_scvi=True`, so that `train_context` below
# fits this model on the target dataset only.
# NOTE(review): `random_seed = i*1234` depends on whatever value `i` currently has
# in the kernel (the first cell sets `i=0`, but loops in other cells rebind `i`) -
# confirm the intended seed and consider hard-coding it.
scvi_model = scSpecies(device, 
                mdata, 
                path,
                context_dataset_key = target_key, 
                target_dataset_key = context_key,
                train_only_scvi=True,          
                random_seed = i*1234,
                use_lib_enc = True,                
                **h_dict
                )


# presumably 30 training epochs - verify against `scSpecies.train_context`
scvi_model.train_context(30, save_model=False)
scvi_model.eval_context()

# Evaluate scSpecies on both datasets and the baseline on its (swapped) context
# dataset, i.e. on the target dataset - hence the `scvi_*_target` names.
scSpecies_interm_context, scSpecies_likeli_list_context, scSpecies_rec_list_context, scSpecies_norm_list_context, scSpecies_l_list_context, scSpecies_l_mu_list_context, scSpecies_l_logsig_list_context, _ = eval_model(scSpecies_model, 'context', b_s=128)
scSpecies_interm_target, scSpecies_likeli_list_target, scSpecies_rec_list_target, scSpecies_norm_list_target, scSpecies_l_list_target, scSpecies_l_mu_list_target, scSpecies_l_logsig_list_target, _ = eval_model(scSpecies_model, 'target', b_s=128)
scvi_interm_target, scvi_likeli_list_target, scvi_rec_list_target, scvi_norm_list_target, scvi_l_list_target, scvi_l_mu_list_target, scvi_l_logsig_list_target, _ = eval_model(scvi_model, 'context', b_s=128)

# Store the library-encoder outputs on the AnnData objects.
# NOTE: the arrays written to 'l_latent_sig' are the `l_log_sig` outputs returned by
# `eval_model` (they are exponentiated there before use), despite the key name.
scSpecies_model.mdata.mod[context_key].obsm['l_latent_mu'] = np.array(scSpecies_l_mu_list_context)
scSpecies_model.mdata.mod[context_key].obsm['l_latent_sig'] = np.array(scSpecies_l_logsig_list_context)
scSpecies_model.mdata.mod[target_key].obsm['l_latent_mu'] = np.array(scSpecies_l_mu_list_target)
scSpecies_model.mdata.mod[target_key].obsm['l_latent_sig'] = np.array(scSpecies_l_logsig_list_target)

In [None]:
# Dense count matrices. Despite the names, `data_counts_human` holds the *target*
# dataset and `data_counts_mouse` the *context* dataset.
data_counts_human = torch.from_numpy(mdata.mod[target_key].X.toarray())
data_counts_mouse = torch.from_numpy(mdata.mod[context_key].X.toarray())

# Empirical per-cell gene fractions with a 1e-6 pseudo-count, computed separately
# within the homologous and the non-homologous gene group.
# Fix: the homologous numerator must be restricted to the homologous columns, as
# the non-homologous one already is. Using all genes made the concatenation below
# wider than the gene axis, so `gene_ind` picked the wrong columns.
rho_pre_hom = (data_counts_human[:, scSpecies_model.target_decoder.homologous_genes]+1e-6)/(
    data_counts_human[:, scSpecies_model.target_decoder.homologous_genes].sum(1).unsqueeze(-1) +1e-6*len(scSpecies_model.target_decoder.homologous_genes))
rho_pre_nonhom = (data_counts_human[:, scSpecies_model.target_decoder.non_hom_genes]+1e-6)/(
    data_counts_human[:, scSpecies_model.target_decoder.non_hom_genes].sum(1).unsqueeze(-1) +1e-6*len(scSpecies_model.target_decoder.non_hom_genes))

# presumably `gene_ind` maps the [homologous | non-homologous] layout back to the
# dataset's gene order - verify against the decoder implementation
data_counts_norm = torch.cat((rho_pre_hom, rho_pre_nonhom), dim=-1)[:, scSpecies_model.target_decoder.gene_ind]

# Rank marker genes per fine cell type on a log1p-transformed copy of the counts
# (the raw counts in X are left untouched).
mdata.mod[target_key].layers['log1p'] = mdata.mod[target_key].X.copy()
sc.pp.log1p(mdata.mod[target_key], layer='log1p')
sc.tl.rank_genes_groups(mdata.mod[target_key], groupby='cell_type_fine', method='wilcoxon', layer='log1p')  

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import convolve1d
from matplotlib.lines import Line2D


# Font sizes passed as `fontsize` / `labelsize` to the matplotlib calls of the
# gene-expression figure.
TITLE_SIZE = LEGEND_SIZE = 17.5
AXIS_LABEL_SIZE = 16.5
TICK_LABEL_SIZE = 14


def get_top_marker_genes(cell_type, n_top=100):
    """Return the first `n_top` ranked marker-gene names stored for `cell_type`.

    Reads the `rank_genes_groups` result from `mdata.mod[target_key].uns`
    (notebook-level `mdata` and `target_key`), so `sc.tl.rank_genes_groups`
    must have been run on that modality beforehand.
    """
    ranked_names = mdata.mod[target_key].uns['rank_genes_groups']['names']
    return list(ranked_names[:n_top][cell_type])

def histogram(data, num_bins=80, max_val=100, kernel_size=7, sigma=2.0, clip_max=0.4):
    """Compute a smoothed density histogram using a Gaussian kernel.

    The density is estimated on `num_bins` equal-width bins over `[0, max_val]`.
    The first bin is kept as-is (it is excluded from the smoothing); all
    remaining bins are convolved with a normalised Gaussian kernel using
    reflecting boundaries.

    Parameters
    ----------
    data : array-like
        Values to histogram.
    num_bins : int
        Number of bins.
    max_val : float
        Upper edge of the last bin.
    kernel_size : int
        Number of taps of the Gaussian smoothing kernel.
    sigma : float
        Standard deviation of the kernel, in units of bins.
    clip_max : float or None
        Densities are clipped to `[0, clip_max]`; `None` disables the upper clip.

    Returns
    -------
    (bin_mids, density) : tuple of np.ndarray
        Bin mid-points and the smoothed, clipped density. For empty `data`
        both are `np.array([0])`.
    """
    data = np.asarray(data)
    if data.size == 0:
        return np.array([0]), np.array([0])

    half_width = kernel_size // 2
    offsets = np.linspace(-half_width, half_width, kernel_size)
    gauss_kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    gauss_kernel /= gauss_kernel.sum()

    bin_edges = np.linspace(0, max_val, num_bins + 1)
    hist_counts, _ = np.histogram(data, bins=bin_edges, density=True)

    tail = hist_counts[1:]
    if tail.size > 0:
        # Only smooth when there is something beyond the first bin.
        tail = convolve1d(tail, gauss_kernel, mode='reflect')
    final_hist = np.concatenate(([hist_counts[0]], tail))

    bin_mids = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    return bin_mids, np.clip(final_hist, a_min=0, a_max=clip_max)


# Cell types annotated in both datasets ...
cell_types_all = np.intersect1d(
    np.array(mdata.mod[target_key].obs.cell_type_fine.unique()),
    np.array(mdata.mod[context_key].obs.cell_type_fine.unique()),
)
# ... restricted to those with at least 250 cells in the target dataset.
filtered = [
    ct for ct in cell_types_all
    if len(np.where(mdata.mod[target_key].obs.cell_type_fine == ct)[0]) >= 250
]
cell_types = np.array(filtered)
n_cell_types = len(cell_types)
n_top_genes = 3

# One row per cell type; each gene occupies two columns (raw, normalized).
fig, axes = plt.subplots(
    n_cell_types,
    2 * n_top_genes,
    figsize=(3.3 * 2 * n_top_genes, 1.75 * n_cell_types),
    squeeze=False,
)

for ct_idx, ct in enumerate(cell_types):

    # Candidate markers: the first 5*n_top_genes ranked genes of this cell type,
    # keeping only those whose annotated mouse gene name exists in the context dataset.
    top_genes_human = get_top_marker_genes(ct, n_top=n_top_genes * 5)
    top_genes_human = [gene for gene in top_genes_human 
                       if np.array(mdata.mod[target_key][:, mdata.mod[target_key].var_names == gene].var.mouse_gene_names)[0]
                          in np.array(mdata.mod[context_key].var_names)]

    # NOTE: the boolean `np.isin` selection returns genes in `var_names` order,
    # not in marker-rank order; both name arrays are derived the same way so
    # target and context genes stay paired.
    top_genes_human = np.array(mdata.mod[target_key][:, np.isin(mdata.mod[target_key].var_names, top_genes_human)]
                               .var_names)[:n_top_genes]
    top_genes_mouse = np.array(mdata.mod[target_key][:, np.isin(mdata.mod[target_key].var_names, top_genes_human)]
                               .var.mouse_gene_names)[:n_top_genes]

    ind_h = np.where(mdata.mod[target_key].obs.cell_type_fine == ct)[0]
    ind_m = np.where(mdata.mod[context_key].obs.cell_type_fine == ct)[0]
    gene_ind = np.array([np.where(mdata.mod[target_key].var_names == gene)[0][0] 
                         for gene in top_genes_human])
    gene_ind_m = np.array([np.where(mdata.mod[context_key].var_names == gene)[0][0] 
                           for gene in top_genes_mouse])

    # Observed counts and model means (normalized expression * library size).
    origs = data_counts_human[ind_h][:, gene_ind]
    origs_mouse = data_counts_mouse[ind_m][:, gene_ind_m]
    recs_scvi = scvi_norm_list_target[ind_h][:, gene_ind] * scvi_l_list_target[ind_h]
    recs_scpecies = scSpecies_norm_list_target[ind_h][:, gene_ind] * scSpecies_l_list_target[ind_h]
    recs_scpecies_m = scSpecies_norm_list_context[ind_m][:, gene_ind_m] * scSpecies_l_list_context[ind_m]

    origs_norm = data_counts_norm[ind_h][:, gene_ind]
    recs_scvi_norm = scvi_norm_list_target[ind_h][:, gene_ind]
    recs_scpecies_norm = scSpecies_norm_list_target[ind_h][:, gene_ind]
    
    for j in range(n_top_genes):

        d_obs   = np.array(origs[:, j])
        d_obs_m = np.array(origs_mouse[:, j])
        d_scvi  = np.array(recs_scvi[:, j])
        d_scs   = np.array(recs_scpecies[:, j])
        d_scs_m = np.array(recs_scpecies_m[:, j])

        # x-axis limit: smallest 96th percentile among the target-side
        # distributions (capped at 1000), but never below 15 or below the
        # 90th percentile of the context-side model mean.
        candidates = []
        if d_obs.size > 0:
            candidates.append(np.sort(d_obs)[int(len(d_obs)*0.96)])
        if d_scvi.size > 0:
            candidates.append(np.sort(d_scvi)[int(len(d_scvi)*0.96)])
        if d_scs.size > 0:
            candidates.append(np.sort(d_scs)[int(len(d_scs)*0.96)])

        max_val_data = np.min(candidates + [1000])
        max_val_data = np.max([max_val_data, 15, np.sort(d_scs_m)[int(len(d_scs_m)*0.9)]])
        if ct == 'Hepatocytes':
            max_val_data = 60
        num_bins_data = int(min(max_val_data, 80))
        
        x_obs, y_obs         = histogram(d_obs, num_bins=num_bins_data, max_val=max_val_data)
        x_obs_m, y_obs_m     = histogram(d_obs_m, num_bins=num_bins_data, max_val=max_val_data)
        x_scvi, y_scvi       = histogram(d_scvi, num_bins=num_bins_data, max_val=max_val_data)
        x_scs, y_scs         = histogram(d_scs, num_bins=num_bins_data, max_val=max_val_data)

        ax_data = axes[ct_idx, 2*j]
        ax_data.step(x_obs_m, y_obs_m, where='mid', color='orange', lw=2)
        ax_data.step(x_obs, y_obs, where='mid', color='purple', lw=2)
        # Fix: colours now match the figure legend (blue = scVI, darkgreen =
        # scSpecies); they were swapped, which mislabelled both model curves.
        ax_data.step(x_scvi, y_scvi, where='mid', color='blue', lw=2)
        ax_data.step(x_scs, y_scs, where='mid', color='darkgreen', lw=2)
        ax_data.set_title(f"{top_genes_human[j]}/{top_genes_mouse[j]} (Raw)", fontsize=TITLE_SIZE)
        ax_data.tick_params(labelsize=TICK_LABEL_SIZE)

        ax_data.set_xlabel('')

        n_data  = np.array(origs_norm[:, j])
        n_scvi  = np.array(recs_scvi_norm[:, j])
        n_scs   = np.array(recs_scpecies_norm[:, j])
        
        max_val_norm = max(#np.max(n_data)*1.1 if n_data.size > 0 else 0,
                           np.max(n_scvi)*1.1 if n_scvi.size > 0 else 0,
                           np.max(n_scs)*1.1 if n_scs.size > 0 else 0,
                           0.005)
        bins_norm = np.linspace(0, max_val_norm, 100)
        # `counts_data` (empirical fractions) is computed but currently not drawn.
        counts_data, _ = np.histogram(n_data, bins=bins_norm, density=True)
        counts_scvi, _ = np.histogram(n_scvi, bins=bins_norm, density=True)
        counts_scs, _  = np.histogram(n_scs, bins=bins_norm, density=True)
        x_vals_norm = (bins_norm[:-1] + bins_norm[1:]) / 2

        ax_norm = axes[ct_idx, 2*j+1]
        ax_norm.step(x_vals_norm, counts_scvi, where='mid', color='deepskyblue', lw=2)
        ax_norm.step(x_vals_norm, counts_scs, where='mid', color='limegreen', lw=2)
        ax_norm.set_title(f"{top_genes_human[j]} (Normalized)", fontsize=TITLE_SIZE)
        ax_norm.tick_params(labelsize=TICK_LABEL_SIZE)
        ax_norm.set_xlabel('')
        ax_norm.set_ylabel('')

# Row labels; the two longest cell-type names get a slightly smaller font.
for ct_idx, ct in enumerate(cell_types):

    if ct in ('Cholangiocytes', 'Cytotoxic CD8+'):
        axes[ct_idx, 0].set_ylabel(f"{ct}\n p.d.f", fontsize=AXIS_LABEL_SIZE-2)
    else:
        axes[ct_idx, 0].set_ylabel(f"{ct}\n p.d.f", fontsize=AXIS_LABEL_SIZE)

    if ct_idx == len(cell_types) - 1:
        # x labels on the bottom row only. Looping over the gene pairs keeps
        # this correct for any `n_top_genes` (columns 0..5 were hard-coded).
        for gene_col in range(n_top_genes):
            axes[ct_idx, 2*gene_col].set_xlabel("Gene expression levels", fontsize=AXIS_LABEL_SIZE)
            axes[ct_idx, 2*gene_col+1].set_xlabel("Normalized gene expression", fontsize=AXIS_LABEL_SIZE)

# Only the first letter of the dataset key is upper-cased for the last two labels;
# `str.capitalize()` would also lower-case the rest (e.g. 'mouseNafld' -> 'Mousenafld').
norm_handles = [Line2D([0], [0], color='deepskyblue', lw=2, label=f'Normalized {target_key} scVI gene expression'),
                Line2D([0], [0], color='limegreen', lw=2, label=f'Normalized {target_key} scSpecies gene expression'),
                Line2D([0], [0], color='blue', lw=2, label=f'scVI {target_key} NB mean parameter'),
                Line2D([0], [0], color='darkgreen', lw=2, label=f'scSpecies {target_key} NB mean parameter'),
                Line2D([0], [0], color='purple', lw=2, label=f'{target_key[:1].upper()}{target_key[1:]} gene expression data distribution'),
                Line2D([0], [0], color='orange', lw=2, label=f'{context_key[:1].upper()}{context_key[1:]} gene expression data distribution'),
                ]
fig.legend(norm_handles, [h.get_label() for h in norm_handles],
           loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=3, fontsize=LEGEND_SIZE)

plt.tight_layout(rect=[0, 0.04, 1, 0.95])
plt.savefig(save_path+save_key+"gene_expression_levels.pdf")
plt.show()


## 4) ELBO convergence

In [None]:
# Per-iteration training histories of the baseline (its "context" is the target
# dataset) and of the scSpecies target model.
context_likeli = np.array(scvi_model.context_likeli_hist_dict)
target_likeli = np.array(scSpecies_model.target_likeli_hist_dict)

# Trailing moving average whose window grows with the iteration count:
# the last max(10, 2% of i) values before iteration i.
adaptive_smoothed_context, adaptive_smoothed_target, adaptive_indices = [], [], []

for i in range(10, len(context_likeli)):
    smooth_window = max(10, int(i*0.02))
    if i < smooth_window:
        continue
    _window = slice(i - smooth_window, i)
    adaptive_smoothed_context.append(np.mean(context_likeli[_window]))
    adaptive_smoothed_target.append(np.mean(target_likeli[_window]))
    adaptive_indices.append(i)

adaptive_smoothed_context = np.array(adaptive_smoothed_context)
adaptive_smoothed_target = np.array(adaptive_smoothed_target)
adaptive_indices = np.array(adaptive_indices)

plt.figure(figsize=(10, 6))
for _curve, _name, _color in (
    (adaptive_smoothed_context, "scVI", 'blue'),
    (adaptive_smoothed_target, "scSpecies", 'orange'),
):
    plt.plot(adaptive_indices, _curve, label=_name, linestyle='-', color=_color)

plt.xscale("log")
_ticks = [10, 100, 1000, 10000]
plt.xticks(_ticks, labels=[str(_t) for _t in _ticks])

plt.xlabel("Iteration")
plt.ylabel("Negative Log-likelihood")
plt.xlim(10,35000)
plt.ylim(1100, 1700)
plt.title("Log-likelihood smoothed over the last 2% of iterations")
plt.legend()
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.tight_layout()
plt.savefig(save_path+save_key+"likelihood_during_training.pdf")
plt.show()

# Last smoothed value of each curve.
print('scVI_final_ELBO', adaptive_smoothed_context[-1])
print('scSpecies_final_ELBO', adaptive_smoothed_target[-1])

## 5) Overrepresentation Enrichment Analysis

In [None]:
def plot_ora(ora_raw, N=10, q=None, cols=4, wrap=28, figsize_scale=4, title="Over-representation analysis, top 8 per cell type", legend_ncol=2):
    """Draw one horizontal bar panel of enriched terms per cell type and save it.

    Parameters
    ----------
    ora_raw : dict
        ``{cell_type: {'up': DataFrame, 'down': DataFrame}}``; each frame needs
        the columns ``'Term'`` and ``'Adjusted P-value'``. A missing direction
        is treated as empty.
    N : int
        Number of terms (largest ``-log10`` adjusted p-value, up and down
        pooled) shown per cell type.
    q : float or None
        If given, only terms with adjusted p-value <= ``q`` are shown.
    cols : int
        Number of panel columns.
    wrap : int
        Term labels are wrapped at this many characters.
    figsize_scale : float
        Height of one panel row.
    title : str
        Figure title. The default text mentions "top 8"; pass a matching
        title when ``N`` differs.
    legend_ncol : int
        Number of columns of the shared legend.

    The figure is written to ``save_path + save_key + "ORA.pdf"`` (notebook-level
    names), but only if at least one panel was drawn.
    """
    celltypes = list(ora_raw.keys())
    rows = max(1, math.ceil(len(celltypes)/cols))
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, figsize_scale*rows), squeeze=False)
    # Smallest positive float: stands in for p-values of exactly 0 so log10 stays finite.
    eps = np.nextafter(0,1)
    needed = ['Term','Adjusted P-value']
    handles_labels = None
    for i, ct in enumerate(celltypes):
        ax = axes[i//cols, i%cols]
        up = ora_raw[ct].get('up', pd.DataFrame(columns=needed))[needed].dropna().copy()
        down = ora_raw[ct].get('down', pd.DataFrame(columns=needed))[needed].dropna().copy()
        if up.empty and down.empty:
            ax.axis('off')
            continue
        up['dir'] = 'up'
        down['dir'] = 'down'
        df = pd.concat([up, down], ignore_index=True)
        if q is not None:
            # Explicit copy: columns are added below, and assigning into a
            # boolean-filtered slice is the classic SettingWithCopy pitfall.
            df = df[df['Adjusted P-value']<=q].copy()
        if df.empty:
            ax.axis('off')
            continue
        df['score'] = -np.log10(df['Adjusted P-value'].replace(0, eps))
        df['label'] = ["\n".join(textwrap.wrap(t, width=wrap)) for t in df['Term']]
        df = df.nlargest(N, 'score')
        sns.barplot(x='score', y='label', data=df, hue='dir', hue_order=['up','down'], dodge=False, orient='h', ax=ax)
        # Keep the legend entries of the first drawn panel for the shared legend
        # and drop the per-panel legends.
        if handles_labels is None:
            handles_labels = ax.get_legend_handles_labels()
        if ax.legend_ is not None:
            ax.legend_.remove()
        ax.set_title(ct, fontsize=12)
        ax.set_xlabel('-log10(adj. p-val.)')
        ax.set_ylabel('')
        ax.tick_params(axis='y', which='both', labelsize=8)
    # Remove the unused trailing axes of the grid.
    for j in range(len(celltypes), rows*cols):
        fig.delaxes(axes.flatten()[j])
    plt.tight_layout(rect=[0, 0, 1, 0.82])
    fig.suptitle(title, y=0.85, fontsize=14)
    if handles_labels is not None:
        handles, labels = handles_labels
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.837), ncol=legend_ncol, frameon=True)  
        plt.savefig(save_path+save_key+"ORA.pdf", bbox_inches='tight')  

if target_key == 'human':

    libs = set(gp.get_library_name(organism='Human'))

    def pick(regex):
        """Among library names matching `regex` (case-insensitive), return the one
        that sorts last by the first 4-digit number in its name (0 if it has
        none); `None` when nothing matches."""
        def _year(name):
            found = re.search(r'(\d{4})', name)
            return int(found.group(1)) if found else 0

        matches = sorted((x for x in libs if re.search(regex, x, re.I)), key=_year)
        return matches[-1] if matches else None

    # One library per family; families without a match are dropped.
    ORA_LIBS = [
        lib
        for lib in [pick(r'^Reactome'), pick(r'^KEGG_.*Human'), pick(r'^GO_Biological_Process'), pick(r'Hallmark')]
        if lib
    ]
    print(ORA_LIBS)

    tau = 0.95
    lfc_thr = 1.0

    # Background: every gene that appears in any per-cell-type LFC table.
    universe = sorted({gene for table in lfc_dict.values() for gene in table.index})

    # Differential genes per cell type: |lfc| >= lfc_thr with p >= tau.
    deg = {}
    for ct, df in lfc_dict.items():
        confident = df['p'] >= tau
        up = df.index[(df['lfc'] >= lfc_thr) & confident]
        down = df.index[(df['lfc'] <= -lfc_thr) & confident]
        if len(up) < 5 and len(down) < 5:
            continue
        deg[ct] = {'up': list(up), 'down': list(down)}

    def run_enrichr(genes):
        """Enrichr over-representation test of `genes` against ORA_LIBS; gene
        lists shorter than 5 yield an empty result frame."""
        if len(genes) >= 5:
            r = gp.enrichr(gene_list=genes, gene_sets=ORA_LIBS, organism='human', background=universe, outdir=None, no_plot=True).results
            return r[['Term','Adjusted P-value']].sort_values('Adjusted P-value').reset_index(drop=True)
        return pd.DataFrame(columns=['Term','Adjusted P-value'])

    ora_raw = {}
    for ct in deg.keys():
        ora_raw[ct] = {'up': run_enrichr(deg[ct]['up']), 'down': run_enrichr(deg[ct]['down'])}

    plot_ora(ora_raw, N=8)        


## 6) Pathway Comparison

In [None]:
def decode_samples(model, context_cell_key, target_cell_key, samples=5000, b_s=128, confidence_level=0.9):
    """Decode latent samples of target cells with both the context and the target decoder.

    For every cell type present in both datasets, target cells are filtered with
    `model.filter_outliers` on their `obsm['latent_mu']`, latent codes are drawn
    around `obsm['z_mu']` with scale `obsm['z_sig']`, and each code is passed
    through `decode_homologous` of both decoders once per eligible batch label
    (batches with more than 3 cells of that type); the per-batch outputs are
    combined with `model.average_slices`.

    Parameters
    ----------
    model
        Trained scSpecies model (networks are switched to eval mode).
    context_cell_key, target_cell_key : str
        `obs` columns holding the cell-type labels of the context / target dataset.
    samples : int
        Maximum number of decoded rows kept per cell type; the filtered cells
        are re-sampled as often as needed to reach it.
    b_s : int
        Mini-batch size.
    confidence_level : float
        Forwarded to `model.filter_outliers`.

    Returns
    -------
    (target_rho_dict, context_rho_dict) : tuple of dict
        Cell type -> array with at most `samples` rows of decoder outputs.

    NOTE(review): `target_cell_key=None` is tested further down, but the labels
    are read from `obs[target_cell_key]` before that and `target_batches` is
    only built for a non-None key, so the None path cannot currently work -
    confirm whether it should be supported.
    """
    for module in (
        model.context_decoder, model.target_decoder,
        model.context_encoder_inner, model.target_encoder_inner,
        model.context_encoder_outer, model.target_encoder_outer,
        model.context_lib_encoder, model.target_lib_encoder,
    ):
        module.eval()

    context_adata = model.mdata.mod[model.context_dataset_key]
    target_adata = model.mdata.mod[model.target_dataset_key]

    context_cell_labels = context_adata.obs[context_cell_key].to_numpy()
    context_cell_types = np.unique(context_cell_labels)

    target_cell_labels = target_adata.obs[target_cell_key].to_numpy()
    target_cell_types = np.unique(target_cell_labels)
    target_cell_index = {c : np.where(target_cell_labels == c)[0] for c in target_cell_types}

    context_batch_key = context_adata.uns['dataset_batch_key']
    target_batch_key = target_adata.uns['dataset_batch_key']

    context_batch_labels = context_adata.obs[context_batch_key].to_numpy().reshape(-1, 1)
    target_batch_labels = target_adata.obs[target_batch_key].to_numpy().reshape(-1, 1)

    context_enc = OneHotEncoder()
    context_enc.fit(context_batch_labels)

    target_enc = OneHotEncoder()
    target_enc.fit(target_batch_labels)

    # Per cell type: one-hot rows of all batches that contain more than 3 cells of it.
    context_batches = {c : context_adata[context_adata.obs[context_cell_key] == c].obs[context_batch_key].value_counts() > 3 for c in context_cell_types}
    context_batches = {c : context_batches[c][context_batches[c]].index.to_numpy() for c in context_cell_types}
    context_batches = {c : context_enc.transform(context_batches[c].reshape(-1, 1)).toarray().astype(np.float32)  for c in context_cell_types}
    context_batches['unknown'] = context_enc.transform(np.unique(context_batch_labels).reshape(-1, 1)).toarray().astype(np.float32)

    if target_cell_key is None:
        joint_cell_types = context_cell_types

    else:
        joint_cell_types = np.intersect1d(context_cell_types, target_cell_types, return_indices=True)[0]
        target_batches = {c : target_adata[target_adata.obs[target_cell_key] == c].obs[target_batch_key].value_counts() > 3 for c in target_cell_types}
        target_batches = {c : target_batches[c][target_batches[c]].index.to_numpy() for c in target_cell_types}
        target_batches = {c : target_enc.transform(target_batches[c].reshape(-1, 1)).toarray().astype(np.float32)  for c in target_cell_types}
        target_batches['unknown'] = target_enc.transform(np.unique(target_batch_labels).reshape(-1, 1)).toarray().astype(np.float32)

    context_rho_dict = {}
    target_rho_dict = {}

    for cell_type in joint_cell_types:
        adata_target = target_adata[target_cell_index[cell_type]]

        filtered_data_ind, _ = model.filter_outliers(adata_target.obsm['latent_mu'], confidence_level=confidence_level)
        adata_target = adata_target[filtered_data_ind]

        steps = np.ceil(adata_target.n_obs/b_s).astype(int)
        # Re-sample the filtered cells often enough to collect `samples` rows.
        iterations = int(np.ceil(samples/adata_target.n_obs))

        with torch.no_grad():
            context_rho_dict[cell_type] = []
            target_rho_dict[cell_type] = []

            for _ in range(iterations):
                for step in range(steps):
                    batch_adata = adata_target[step*b_s:(step+1)*b_s]
                    # Both label arrays come from the target cells' own annotation.
                    context_cell_type = batch_adata.obs[batch_adata.uns['dataset_cell_key']].to_numpy()
                    target_cell_type = batch_adata.obs[batch_adata.uns['dataset_cell_key']].to_numpy() #np.array(['unknown']*batch_adata.n_obs)

                    context_labels = np.concatenate([context_batches[c] for c in context_cell_type])
                    target_labels = np.concatenate([target_batches[c] for c in target_cell_type])

                    context_labels = torch.from_numpy(context_labels).to(model.device)
                    target_labels = torch.from_numpy(target_labels).to(model.device)

                    # Number of batch labels each cell is decoded with.
                    context_ind_batch = np.array([np.shape(context_batches[c])[0] for c in context_cell_type])
                    target_ind_batch = np.array([np.shape(target_batches[c])[0] for c in target_cell_type])

                    shape = np.shape(batch_adata.obsm['z_sig'])
                    # Fix: draw standard-normal noise (randn); `np.random.rand`
                    # is uniform on [0, 1) and shifts every sample upwards.
                    z = np.float32(batch_adata.obsm['z_mu'] + batch_adata.obsm['z_sig'] * np.random.randn(shape[0], shape[1]))

                    # Repeat each latent code once per batch label it is decoded with.
                    context_z = np.repeat(z, context_ind_batch, axis=0)
                    target_z = np.repeat(z, target_ind_batch, axis=0)

                    context_z = torch.from_numpy(context_z).to(model.device)
                    target_z = torch.from_numpy(target_z).to(model.device)

                    context_rho = model.context_decoder.decode_homologous(context_z, context_labels).cpu().numpy()
                    context_rho = model.average_slices(context_rho, context_ind_batch)

                    target_rho = model.target_decoder.decode_homologous(target_z, target_labels).cpu().numpy()
                    target_rho = model.average_slices(target_rho, target_ind_batch)

                    context_rho_dict[cell_type].append(context_rho)
                    target_rho_dict[cell_type].append(target_rho)

        target_rho_dict[cell_type] = np.concatenate(target_rho_dict[cell_type])[:samples]
        context_rho_dict[cell_type] = np.concatenate(context_rho_dict[cell_type])[:samples]

    return target_rho_dict, context_rho_dict


def load_and_filter_pathways(gmt_path, adata, min_genes=5):
    """Read gene sets from a GMT file and keep those sufficiently covered by `adata`.

    Each non-empty line is split on tabs: field 0 is the set name, field 1 is
    skipped and the remaining fields are gene identifiers. If a name occurs more
    than once, the last occurrence wins.

    Parameters
    ----------
    gmt_path : str
        Path of the GMT file.
    adata
        Object whose `var_names` lists the available genes.
    min_genes : int
        Minimum number of distinct set genes that must be present in `adata`.

    Returns
    -------
    dict
        Set name -> list of the set's genes found in `adata.var_names`,
        de-duplicated and in the order they appear in the file. (The order is
        deterministic; it used to be an arbitrary set order.)
    """
    pathways = {}
    with open(gmt_path, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # blank lines carry no gene set
                continue
            parts = stripped.split('\t')
            pathways[parts[0]] = parts[2:]

    var_set = set(adata.var_names)
    filtered = {}
    for name, genes in pathways.items():
        # dict.fromkeys de-duplicates while preserving file order
        overlap = list(dict.fromkeys(gene for gene in genes if gene in var_set))
        if len(overlap) >= min_genes:
            filtered[name] = overlap

    return filtered


def ensure_model_device(model):
    """Move the eight encoder/decoder networks of `model` to `model.device`.

    Returns the resolved `torch.device`.
    """
    target_device = torch.device(model.device)
    component_names = (
        'context_decoder', 'target_decoder',
        'context_encoder_inner', 'target_encoder_inner',
        'context_encoder_outer', 'target_encoder_outer',
        'context_lib_encoder', 'target_lib_encoder',
    )
    # Resolve every attribute first so a missing network fails before anything is moved.
    components = [getattr(model, attr) for attr in component_names]
    for component in components:
        component.to(target_device)
    return target_device

def mode_histogram(x, bins='fd'):
    """Estimate the mode of ``x`` as the centre of its fullest histogram bin.

    Parameters
    ----------
    x : array-like
        Values to histogram.
    bins : int, sequence or str, default 'fd'
        Forwarded unchanged to ``np.histogram``.

    Returns
    -------
    float
        Midpoint of the bin with the highest count; if several bins tie, the
        first (lowest) one is used because ``np.argmax`` returns the first hit.
    """
    bin_counts, bin_edges = np.histogram(x, bins=bins)
    peak = int(np.argmax(bin_counts))
    lower_edge, upper_edge = bin_edges[peak], bin_edges[peak + 1]
    return (lower_edge + upper_edge) / 2.0



In [None]:
if target_key == 'human':
    # Pathway-activity comparison between decoded mouse (context) and human
    # (target) expression. The block hard-codes the 'human' modality and the
    # 'mouse'/'human' species labels, hence the guard on target_key.

    ensure_model_device(scSpecies_model)

    # Both dictionaries map a cell type to an array holding at most `samples`
    # decoded rows for that cell type.
    target_rho_dict, context_rho_dict = decode_samples(scSpecies_model, 'cell_type_fine', 'cell_type_fine', samples=1000, b_s=128, confidence_level=0.95)

    adata_h = ad.concat([ad.AnnData(target_rho_dict[key]) for key in target_rho_dict.keys()])
    adata_h.var_names = scSpecies_model.mdata.mod['human'][:, scSpecies_model.target_param_dict['homologous_genes']].var_names
    # np.concatenate instead of np.concat: the latter only exists in NumPy >= 2.0.
    adata_h.obs['cell_type_fine'] = np.concatenate([[key]*np.shape(target_rho_dict[key])[0] for key in target_rho_dict.keys()])

    adata_m = ad.concat([ad.AnnData(context_rho_dict[key]) for key in context_rho_dict.keys()])
    # The mouse matrix reuses the human gene names so that both species share
    # one gene axis for concatenation and pathway scoring.
    adata_m.var_names = adata_h.var_names
    adata_m.obs['cell_type_fine'] = np.concatenate([[key]*np.shape(context_rho_dict[key])[0] for key in context_rho_dict.keys()])

    gene_sets_path = path+'dataset/c2.all.v2024.1.Hs.symbols.gmt'
    pathways = load_and_filter_pathways(gene_sets_path, adata_h)

    # ad.concat replaces the deprecated AnnData.concatenate; label/keys/index_unique
    # reproduce its batch_key/batch_categories/default index suffix behaviour.
    adata = ad.concat(
        [adata_m, adata_h],
        label="species",
        keys=["mouse", "human"],
        index_unique="-"
    )

    sc.pp.log1p(adata)

    # One obs column per pathway, named after the pathway.
    for key in pathways:
        sc.tl.score_genes(adata, gene_list=pathways[key], score_name=key)

    # Take the score columns from the pathway names rather than from their
    # position in adata.obs.
    scores_cols = list(pathways)

    cell_type_names = np.unique(adata.obs['cell_type_fine'].values)
    summary_scores = {ct: {} for ct in cell_type_names}

    # Per cell type and pathway: (mouse mode, human mode, mouse mode - human mode).
    for ct in summary_scores.keys():
        is_ct = (adata.obs['cell_type_fine'] == ct).to_numpy()
        adata_cell_m = adata[is_ct & (adata.obs['species'] == 'mouse').to_numpy()]
        adata_cell_h = adata[is_ct & (adata.obs['species'] == 'human').to_numpy()]
        ct_scores = {}
        for pathway in scores_cols:
            # Each histogram mode is computed once and reused for the difference.
            mode_m = mode_histogram(adata_cell_m.obs[pathway])
            mode_h = mode_histogram(adata_cell_h.obs[pathway])
            ct_scores[pathway] = (mode_m, mode_h, mode_m - mode_h)
        summary_scores[ct] = ct_scores

    mean_diff, top_down, top_up, top_diff = {}, {}, {}, {}
    for ct in cell_type_names:
        diffs = np.array([summary_scores[ct][pathway][-1] for pathway in scores_cols])
        abs_diffs = np.abs(diffs)
        mean_diff[ct] = np.mean(abs_diffs)
        top_down[ct] = (scores_cols[np.argmin(diffs)], np.min(diffs))
        top_up[ct] = (scores_cols[np.argmax(diffs)], np.max(diffs))
        # Name and value must refer to the same pathway: both use |difference|.
        top_diff[ct] = (scores_cols[np.argmax(abs_diffs)], np.max(abs_diffs))

    def wrap_name(name, width=30):
        """Replace underscores with spaces and wrap the name to `width` characters per line."""
        pretty = name.replace('_',' ')
        return '\n'.join(textwrap.wrap(pretty, width=width))

    cell_types = list(top_down.keys())
    n = len(cell_types)
    cols = 5
    rows = math.ceil(n/cols)

    fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*4), sharey=False)
    axes = axes.flatten()

    # One panel per cell type showing the score distributions of its most
    # positive ("up") and most negative ("down") mouse-minus-human pathway.
    for ax, ct in zip(axes, cell_types):
        pw_up   = top_up[ct][0]
        pw_down = top_down[ct][0]
        mean_val = mean_diff[ct]

        df_ct = (
            adata[adata.obs['cell_type_fine'] == ct].obs
            .loc[:, ['species', pw_up, pw_down]]
            .melt(id_vars='species',
                value_vars=[pw_up, pw_down],
                var_name='pathway',
                value_name='score')
        )

        df_ct['pathway'] = df_ct['pathway'].map({
            pw_up:   'Top-up',
            pw_down: 'Top-down'
        })

        sns.violinplot(
            ax=ax, data=df_ct,
            x='pathway', y='score', hue='species',
            palette={'human':'C0','mouse':'C1'},
            dodge=True, inner='quartile'
        )

        title = (
            f"{ct}, Δ mean = {mean_val:.6f}\n"
            f"Up:   {wrap_name(pw_up)}\n"
            f"Down: {wrap_name(pw_down)}"
        )
        ax.set_title(title, fontsize=13, pad=6)

        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.tick_params(axis='x', labelsize=11)
        ax.tick_params(axis='y', labelsize=11)
        # Per-panel legends are dropped in favour of one figure-level legend;
        # guard against an axis on which no legend was drawn.
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    # Hide the unused panels of the grid.
    for ax in axes[n:]:
        ax.axis('off')

    handles, labels = axes[0].get_legend_handles_labels()
    labels = [lab.capitalize() for lab in labels]

    fig.legend(
        handles, labels, title='Species',
        ncol=2, loc='lower center',
        bbox_to_anchor=(0.5, -0.02)
    )

    fig.suptitle(
        f"Comparison of pathway with the highest difference in mean activity scores across {target_key} vs {context_key} liver cell types",
        fontsize=18, fontweight='bold'
    )
    plt.tight_layout(rect=[0, 0.03, 1, 1])
    # Make sure the output directory exists so the figure is not lost after
    # the expensive decoding and scoring above.
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(save_path+save_key+"pathways.pdf", bbox_inches='tight')

    plt.show()