In [5]:
from sklearn.metrics.pairwise import rbf_kernel
import numpy as np
import anndata as ad
# Load the AnnData file holding both measured and predicted cells.
# `ad.read` is a deprecated alias; `ad.read_h5ad` is the supported reader for .h5ad files.
# On load anndata warns that observation names are not unique; the selections below
# use boolean masks on `.obs` columns only, so they do not rely on obs names.
perio = ad.read_h5ad('pred_perio_corrected_dbl (3).h5ad')

# Cells flagged as the corrected ground truth for the P. gingivalis condition.
is_true_pg = (perio.obs['state'] == 'true_corrected') & (perio.obs['drug'] == 'P. gingivalis')
target = perio[is_true_pg].X

# Cells flagged as model predictions.
# NOTE(review): unlike `target`, this selection is not filtered by `drug` -- confirm that
# the file only contains predictions for the P. gingivalis condition.
transport = perio[perio.obs['state'] == 'predicted'].X
def mmd_distance(x, y, gamma):
    """Return an RBF-kernel discrepancy between the samples `x` and `y`.

    The value is mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y)), where K is
    scikit-learn's `rbf_kernel` evaluated with bandwidth parameter `gamma`.
    Each mean is taken over every entry of the corresponding kernel matrix
    (diagonal included).

    Args:
        x: First sample, passed as-is to `rbf_kernel`.
        y: Second sample, passed as-is to `rbf_kernel`.
        gamma: Kernel coefficient forwarded to `rbf_kernel`.

    Returns:
        The scalar discrepancy described above.
    """
    within_x = rbf_kernel(x, x, gamma=gamma).mean()
    cross = rbf_kernel(x, y, gamma=gamma).mean()
    within_y = rbf_kernel(y, y, gamma=gamma).mean()

    return within_x + within_y - 2 * cross


def compute_scalar_mmd(target, transport, gammas=None):
    """Average `mmd_distance(target, transport, gamma)` over several gammas.

    Args:
        target: Reference sample, forwarded to `mmd_distance`.
        transport: Sample to compare against `target`.
        gammas: Iterable of kernel coefficients. When None, the grid
            [2, 1, 0.5, 0.1, 0.01, 0.005] is used.

    Returns:
        `np.mean` of the per-gamma values. A gamma whose computation raises
        `ValueError` contributes NaN, which makes the returned mean NaN.
    """
    if gammas is None:
        gammas = [2, 1, 0.5, 0.1, 0.01, 0.005]

    per_gamma = []
    for gamma in gammas:
        try:
            value = mmd_distance(target, transport, gamma)
        except ValueError:
            # Best-effort: record the failure as NaN instead of aborting.
            value = np.nan
        per_gamma.append(value)

    return np.mean(per_gamma)


Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


In [None]:
# Score the predictions against the ground truth using the default gamma grid
# (omitting `gammas` is the same as passing None).
compute_scalar_mmd(target, transport)

In [1]:
import os

# Stimulations and cell types; one config file is written per (stim, cell) pair.
perio_stim_list = ['TNFa', 'P._gingivalis']
perio_cell_list = [
    'Granulocytes (CD45-CD66+)',
    'B-Cells (CD19+CD3-)',
    'Classical Monocytes (CD14+CD16-)',
    'MDSCs (lin-CD11b-CD14+HLADRlo)',
    'mDCs (CD11c+HLADR+)',
    'pDCs(CD123+HLADR+)',
    'Intermediate Monocytes (CD14+CD16+)',
    'Non-classical Monocytes (CD14-CD16+)',
    'CD56+CD16- NK Cells',
    'CD56loCD16+NK Cells',
    'NK Cells (CD7+)',
    'CD4 T-Cells',
    'Tregs (CD25+FoxP3+)',
    'CD8 T-Cells',
    'CD8-CD4- T-Cells',
]

# Output directory for the generated configs (created if it does not exist).
config_dir = "./yaml/perio_just_concat"
os.makedirs(config_dir, exist_ok=True)

# YAML template; only `{data_path}` is substituted per config.
template_yaml = """data:
  features: ./datasets/perio_just_concat/features.txt
  path: {data_path}
  condition: drug
  source: Unstim
  type: cell

dataloader:
  batch_size: 256
  shuffle: true

datasplit:
  groupby: drug
  name: train_test
  test_size: 0.2
"""

# Write one config file per stimulation / cell-type combination.
for stim in perio_stim_list:
    for cell in perio_cell_list:
        # Shared stem of the dataset and config file names; spaces in the
        # cell-type label become underscores, other characters are kept.
        file_stem = "perio_data_sherlock_" + stim + "_" + "_".join(cell.split(' '))

        data_path = "./datasets/perio_just_concat/" + file_stem + ".h5ad"
        config_filename = file_stem + "_train.yaml"
        config_filepath = os.path.join(config_dir, config_filename)
        config_content = template_yaml.format(data_path=data_path)

        with open(config_filepath, "w") as f:
            f.write(config_content)

        print("Config generated: " + config_filepath)


Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_Granulocytes_(CD45-CD66+)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_B-Cells_(CD19+CD3-)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_Classical_Monocytes_(CD14+CD16-)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_MDSCs_(lin-CD11b-CD14+HLADRlo)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_mDCs_(CD11c+HLADR+)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_pDCs(CD123+HLADR+)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_Intermediate_Monocytes_(CD14+CD16+)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_Non-classical_Monocytes_(CD14-CD16+)_train.yaml
Config generated: ./yaml/perio_just_concat/perio_data_sherlock_TNFa_CD56+CD16-_NK_Cells_train.yaml
Config generated: ./yaml/perio_just_concat/peri