## Infer Causal Structure on Scanpy Data

#### Structure:
A: Load Data from file & look at structure

B: Algorithms
1. GRNBoost2
2. GIES
3. DCDI

Dependencies:
 use a conda-env with:
 - scanpy python-igraph leidenalg

 GRNBoost:
 - conda install -c bioconda arboreto
 
 GIES:
 - pip install gies

In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

import scp_infer as scpi

from filter_adata import *
from plot_adata import *



In [2]:
# Input file: the preprocessed AnnData object (per its name: scaled, 10-gene subset) that the next cell reads
results_file = '../data/edited/Schraivogel_chr8_ad-scaled_10gene.h5ad'

1. Read File

In [3]:
# Load the AnnData object (expression matrix + per-cell/per-gene annotations) from disk
adata = sc.read_h5ad(filename=results_file)

Inspect the data: gene and cell names, per-cell metadata, and the expression distribution:

In [4]:
# Step 1: gene (variable) and cell (observation) names from the AnnData object
gene_names = adata.var_names
cell_names = adata.obs_names

print(len(gene_names), "genes: ", list(gene_names[:3]))
print(len(cell_names), "cells: ", list(cell_names[:1]))

# Step 2: per-cell metadata, which also carries the perturbation annotation
metadata = adata.obs
metadata.head()

10 genes:  ['CCNE2', 'CPQ', 'DSCC1']
3638 cells:  ['TGATTGACAAACCTGAGAGCTATA-sample_14']


Unnamed: 0_level_0,replicate,tissue_type,cell_line,cancer,disease,celltype,organism,perturbation,perturbation_type,ncounts,...,percent_ribo,nperts,n_genes_by_counts,total_counts,n_genes,total_counts_mt,pct_counts_mt,non-targeting,multiplet,gene_pert
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TGATTGACAAACCTGAGAGCTATA-sample_14,sample_14,cell_line,K562,True,chronic myelogenous leukemia,lymphoblasts,human,RIPK2,CRISPR,1247.0,...,7.618284,3,61,1247.0,61,0.0,0.0,False,False,True
TGATTGACAAACCTGAGTCGAGTG-sample_14,sample_14,cell_line,K562,True,chronic myelogenous leukemia,lymphoblasts,human,DSCC1,CRISPR,2615.0,...,6.462715,3,65,2615.0,65,0.0,0.0,False,False,True
TGATTGACAAACCTGCAACTTGAC-sample_14,sample_14,cell_line,K562,True,chronic myelogenous leukemia,lymphoblasts,human,OXR1,CRISPR,1445.0,...,9.757786,3,63,1445.0,63,0.0,0.0,False,False,True
TGATTGACAAACCTGCAGTATCTG-sample_14,sample_14,cell_line,K562,True,chronic myelogenous leukemia,lymphoblasts,human,non-targeting,CRISPR,1711.0,...,8.065459,2,72,1711.0,72,0.0,0.0,True,False,False
TGATTGACAAACCTGCATGCAATC-sample_14,sample_14,cell_line,K562,True,chronic myelogenous leukemia,lymphoblasts,human,STK3,CRISPR,974.0,...,11.704312,3,60,974.0,60,0.0,0.0,False,False,True


In [5]:
# Perturbation labels of the first ten cells (positional slice)
print('Perturbations: ', adata.obs['perturbation'].iloc[:10].tolist())

# Summary statistics of perturbed vs. non-targeting expression values
print_expression_mean_std(adata)

Perturbations:  ['RIPK2', 'DSCC1', 'OXR1', 'non-targeting', 'STK3', 'FAM83A', 'non-targeting', 'non-targeting', 'RIPK2', 'non-targeting']

Perturbed Gene Expression:
Mean:  -1.3400286092269702
Std:  1.4459721913667531
Min:  -10.812207221984863
Max:  1.3513376712799072
95% percentile:  -2.564718008041382  -  0.5875126719474792



Non-Target Gene Expression:
Mean:  -0.08327366830923486
Std:  0.9816791315723081
Min:  -2.78004789352417
Max:  1.833469271659851
95% percentile:  -1.8425889492034913  -  0.9762625455856323


# B. Algorithms

### 1. GRNBoost2

In [6]:
# GRNBoost2 via the scp_infer wrapper.
# NOTE(review): the recorded run of this cell failed inside dask/distributed
# (msgpack ExtraData while registering the client, then CancelledError). This is
# likely a dask/distributed vs. arboreto version mismatch in the env — TODO confirm.
run_GRNBoost = True
if run_GRNBoost:
    grnb = scpi.inference.grnboost2.GRNBoost2Imp(adata, verbose=True)
    grnb.convert_data()
    grnb.infer(plot=True)

Running GRNBoost2
preparing dask client


parsing input
creating dask graph
4 partitions
computing dask graph


2024-03-07 16:01:41,775 - distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/protocol/core.py", line 160, in loads
    return msgpack.loads(
           ^^^^^^^^^^^^^^
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/msgpack/fallback.py", line 136, in unpackb
    raise ExtraData(ret, unpacker._get_extradata())
msgpack.exceptions.ExtraData: unpack(b) received extra data.
2024-03-07 16:01:41,850 - distributed.core - ERROR - Exception while handling op register-client
Traceback (most recent call last):
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/core.py", line 968, in _handle_comm
    result = await result
             ^^^^^^^^^^^^
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/scheduler.py", line 5532, in add_client
    await self.handle_stream(comm=comm, extr

shutting down client and local cluster


2024-03-07 16:01:42,514 - distributed.client - ERROR - 
Traceback (most recent call last):
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/utils.py", line 832, in wrapper
    return await func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/client.py", line 1328, in _reconnect
    await self._ensure_connected(timeout=timeout)
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/client.py", line 1382, in _ensure_connected
    msg = await wait_for(comm.read(), timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/utils.py", line 1935, in wait_for
    return await fut
           ^^^^^^^^^
  File "/home/jans/miniconda3/envs/py-infer/lib/python3.12/site-packages/distributed/comm/tcp.py", line 225, in read
    frames_nosplit_nbytes_bin = await str

finished


CancelledError: finalize-803844ca3d7dbce0fb18b228e1f2c77a

### 2. GIES

1. Reshape Count matrix
2. Run GIES


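GIES matrix format (samples grouped by intervention target):
- data: n_interventions × n_samples per intervention (truncated to the smallest group) × n_variables
- interventions: list of length n_interventions; each entry lists the intervened variable indices (`[]` = observational)

Data distribution:
- the data are scaled to mean 0 and std 1

-> intervened (perturbed) genes take values well below 0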

In [7]:
# run_GIES: fit GIES in this notebook; data_GIES: only prepare/export the GIES input data
run_GIES = False
data_GIES = True
if run_GIES or data_GIES:
    import sys

    # Make the vendored GIES checkout importable. Prepend it only once, so that
    # re-running this cell does not keep growing sys.path.
    if './gies-master' not in sys.path:
        sys.path.insert(0, './gies-master')

    import gies

In [8]:
if run_GIES or data_GIES:
    # Group the samples by intervention target: `intervention_list` holds the intervened
    # variable indices per group and `data_matrix` the matching samples.
    # (create_data_matrix_gies comes from the star imports of the first cell — TODO confirm its return types.)
    intervention_list, data_matrix = create_data_matrix_gies(adata)

    # Look at results
    print(adata.obs['gene_pert'].sum(), " gene perturbations")
    print(len(intervention_list), " interventions")
    print("Intervention list: ", intervention_list[:15])

    print("")
    print("Data matrix:")
    print("Length of data matrix: ", len(data_matrix))

    # Number of samples per intervention group, built in one pass
    # (np.append in a loop copies the whole array on every iteration)
    length = np.array([len(sub_array) for sub_array in data_matrix])

    print("Minimum length: ", np.min(length))
    print("Maximum length: ", np.max(length))
    print("Average length: ", np.mean(length))
    print("Total Samples: ", np.sum(length))
    # Assumes group 0 is the observational (non-intervened) one — TODO confirm this always holds
    print("Total interventional Samples: ", np.sum(length[1:]))

    print("Entries per Intervention: ", length)

    # Conversions to get the data into the right format:
    # truncate every group to the smallest group size to obtain a regular 3-D array
    min_length = int(np.min(length))
    data_matrix = [sub_array[:min_length] for sub_array in data_matrix]

    # Keep floating point: the expression data are scaled (mean 0, std 1), so an integer
    # cast would truncate nearly every value towards 0 and destroy the signal GIES needs.
    data_matrix = np.array(data_matrix, dtype=float)
    print("GIES final data shape: ", np.shape(data_matrix))

Intervention List created:  11 unique perturbations
1921  gene perturbations
11  interventions
Intervention list:  [[], [0], [1], [183], [184], [185], [186], [187], [188], [189], [190]]

Data matrix:
Length of data matrix:  11
Minimum length:  92.0
Maximum length:  1717.0
Average length:  330.72727272727275
Total Samples:  3638.0
Total interventional Samples:  1921.0
Entries per Intervention:  [1717.  228.  194.   92.  269.  220.  270.  203.  183.  154.  108.]
GIES final data shape:  (11, 92, 191)


In [9]:
# Save the data if it should be used externally
if data_GIES:
    import json
    import os

    gies_out_dir = "./data/temp"
    # np.save and open() do not create missing directories, so create the folder first
    os.makedirs(gies_out_dir, exist_ok=True)

    np.save(os.path.join(gies_out_dir, "gies_data_matrix.npy"), data_matrix)

    with open(os.path.join(gies_out_dir, "gies_intervention_list.json"), 'w') as f:
        # indent=2 keeps the nested lists human-readable.
        # default=int turns numpy integer indices (not JSON-serialisable) into plain ints,
        # should the list contain any.
        json.dump(intervention_list, f, indent=2, default=int)

In [10]:
# Run GIES with its BIC score on the prepared data.
# A0=None: no initial graph is supplied (see the gies docs for the default starting graph).
if run_GIES:
    estimate, score = gies.fit_bic(data_matrix, intervention_list, A0=None)
    print(estimate)

### 3. DCDI

In [11]:
# DCDI-only dependencies: imported only when DCDI is enabled, so the rest of the
# notebook runs without torch/cdt installed. numpy is already imported as np in the first cell.
run_DCDI = False
if run_DCDI:
    import argparse
    import os
    import sys

    import cdt
    import torch

Detecting CUDA device(s) : [0]


In [12]:
if run_DCDI:
    current_dir = os.path.abspath(".")

    print("Current dir: ", current_dir)
    # Make the vendored DCDI code importable. Skip entries that are already present,
    # so that re-running this cell does not append duplicates to sys.path.
    for dcdi_dir in (os.path.join(current_dir, 'dcdi_implementation'),
                     os.path.join(current_dir, 'dcdi_implementation', 'dcdi_master', 'dcdi')):
        if dcdi_dir not in sys.path:
            sys.path.append(dcdi_dir)

    print(sys.path)

Current dir:  /gpfs/bwfor/home/hd/hd_hd/hd_pi226/code-python/py-infer
['/gpfs/bwfor/home/hd/hd_hd/hd_pi226/code-python/py-infer', '/home/hd/hd_hd/hd_pi226/.conda/envs/pytorch/lib/python311.zip', '/home/hd/hd_hd/hd_pi226/.conda/envs/pytorch/lib/python3.11', '/home/hd/hd_hd/hd_pi226/.conda/envs/pytorch/lib/python3.11/lib-dynload', '', '/home/hd/hd_hd/hd_pi226/.conda/envs/pytorch/lib/python3.11/site-packages', '/gpfs/bwfor/home/hd/hd_hd/hd_pi226/code-python/py-infer/dcdi_implementation', '/gpfs/bwfor/home/hd/hd_hd/hd_pi226/code-python/py-infer/dcdi_implementation/dcdi_master/dcdi']


In [13]:
if run_DCDI:
    import dcdi_master as dcdi
    from dcdi_master.dcdi.models.learnables import LearnableModel_NonLinGaussANM
    from dcdi_master.dcdi.models.flows import DeepSigmoidalFlowModel
    from dcdi_master.dcdi.train import train, retrain, compute_loss
    from dcdi_master.dcdi.data import DataManagerFile
    from dcdi_master.dcdi.utils.save import dump

    from dcdi_load import DataManagerAnndata

In [14]:
def _print_metrics(stage, step, metrics, throttle=None):
    for k, v in metrics.items():
        print("    %s:" % k, v)

def file_exists(prefix, suffix):
    """Return True if the path obtained by joining `prefix` and `suffix` exists."""
    candidate = os.path.join(prefix, suffix)
    return os.path.exists(candidate)

In [15]:
if run_DCDI:
    # Parameters for the DCDI algorithm, stored as attributes of an argparse.Namespace
    # (mirrors the command-line options of the original DCDI script).
    opt = argparse.Namespace(
        # --- experiment ---
        exp_path='./dcdi_implementation/exp_10genes_100k',  # path to experiments
        train=True,                 # run `train`, producing the /train folder
        retrain=False,              # after to-dag or pruning, retrain from scratch before reporting nll-val
        dag_for_retrain=None,       # path to a DAG in .npy format used for retraining, e.g. /code/stuff/DAG.npy
        random_seed=42,             # random seed for pytorch and numpy

        # --- data ---
        data_path=None,             # path to data files
        i_dataset=None,             # dataset index
        num_vars=len(adata.var_names),  # number of variables (one per gene)
        train_samples=0.8,          # share of samples used for training (default: 80% of the total)
        test_samples=None,          # samples used for testing (default: whatever is not used for training)
        num_folds=5,                # number of folds for cross-validation
        fold=0,                     # fold used for testing
        train_batch_size=64,        # samples per minibatch
        num_train_iter=1_000_000,   # number of meta gradient steps
        normalize_data=False,       # apply (x - mu) / std
        regimes_to_ignore=None,     # regimes removed from the data set when loading
        test_on_new_regimes=False,  # with regimes_to_ignore: evaluate on regimes never seen in training (use after retraining)

        # --- model ---
        model='DCDI-G',             # model class: DCDI-G or DCDI-DSF
        num_layers=2,               # hidden layers
        hid_dim=16,                 # hidden units per layer
        nonlin='leaky-relu',        # leaky-relu | sigmoid
        flow_num_layers=2,          # hidden layers of the DSF
        flow_hid_dim=16,            # hidden units of the DSF

        # --- intervention ---
        intervention=True,          # use data with interventions
        dcd=False,                  # use DCD (DCDI with a loss ignoring the intervention)
        intervention_type="imperfect",    # perfect or imperfect
        intervention_knowledge="known",   # whether the intervention targets are known or unknown
        coeff_interv_sparsity=1e-8, # regularisation coefficient for unknown interventions (lambda_R)

        # --- optimization ---
        optimizer="rmsprop",        # sgd | rmsprop
        lr=1e-3,                    # learning rate
        lr_reinit=None,             # learning rate after the first subproblem (None: reuse lr)
        lr_schedule=None,           # None | sqrt-mu | log-mu: change the initial lr as a function of mu
        stop_crit_win=100,          # window size for the stopping criterion
        reg_coeff=0.1,              # regularization coefficient (lambda)

        # --- augmented Lagrangian ---
        omega_gamma=1e-4,           # precision to declare convergence of subproblems
        omega_mu=0.9,               # after a solved subproblem, h should have shrunk by this ratio
        mu_init=1e-8,               # initial value of mu
        mu_mult_factor=2,           # multiply mu by this when the constraint does not decrease enough
        gamma_init=0.0,             # initial value of gamma
        h_threshold=1e-8,           # stop when |h| < this value (0: stop only when h == 0)

        # --- misc ---
        patience=10,                # early-stopping patience in retrain
        train_patience=5,           # early-stopping patience in train after the constraint
        train_patience_post=5,      # early-stopping patience in train after thresholding

        # --- logging ---
        plot_freq=100,              # plotting frequency
        no_w_adjs_log=False,        # skip logging the weighted adjacency (saves RAM; drops the A_phi plot)
        plot_density=False,         # plot density (only implemented for 2 variables)

        # --- device and numerical precision ---
        gpu=True,                   # use the GPU
        float=False,                # use float32 instead of float64
    )

    plotting_callback = None

In [None]:
if run_DCDI:
    # Control as much randomness as possible
    torch.manual_seed(opt.random_seed)
    np.random.seed(opt.random_seed)

    # --- validate option combinations before doing any expensive work ---
    if opt.lr_reinit is not None:
        assert opt.lr_schedule is None, "--lr-reinit and --lr-schedule are mutually exclusive"
    if opt.test_on_new_regimes:
        # The evaluation at the end uses `best_model`, which only the retrain branch assigns
        # (retrain is an `elif` of train). It also needs the list of held-out regimes.
        if opt.train or not opt.retrain:
            raise ValueError("opt.test_on_new_regimes requires opt.retrain=True and opt.train=False")
        if opt.regimes_to_ignore is None:
            raise ValueError("opt.test_on_new_regimes requires opt.regimes_to_ignore to be set")
    if opt.gpu and not torch.cuda.is_available():
        raise RuntimeError("opt.gpu is True but no CUDA device is available; set opt.gpu = False")

    # Metric callback handed to train/retrain: prints every reported metric
    metrics_callback = _print_metrics

    # adjust some default hparams: without an explicit value, lr_reinit reuses lr
    if opt.lr_reinit is None:
        opt.lr_reinit = opt.lr

    # Default tensor type = device (CUDA/CPU) x precision (float32/float64).
    # NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1.
    # The suggested set_default_dtype/set_default_device pair may not be a drop-in
    # replacement for code relying on the legacy default type (e.g. `.type(torch.Tensor)`
    # below) — TODO migrate and re-test.
    if opt.gpu:
        default_tensor_type = 'torch.cuda.FloatTensor' if opt.float else 'torch.cuda.DoubleTensor'
    else:
        default_tensor_type = 'torch.FloatTensor' if opt.float else 'torch.DoubleTensor'
    torch.set_default_tensor_type(default_tensor_type)

    # create DataManagers for training and testing; the test split reuses the training mean/std
    train_data = DataManagerAnndata(opt.data_path, adata, opt.train_samples, opt.test_samples, train=True,
                                    normalize=opt.normalize_data,
                                    random_seed=opt.random_seed,
                                    intervention=opt.intervention,
                                    intervention_knowledge=opt.intervention_knowledge,
                                    dcd=opt.dcd)
    test_data = DataManagerAnndata(opt.data_path, adata, opt.train_samples, opt.test_samples, train=False,
                                   normalize=opt.normalize_data, mean=train_data.mean, std=train_data.std,
                                   random_seed=opt.random_seed,
                                   intervention=opt.intervention,
                                   intervention_knowledge=opt.intervention_knowledge,
                                   dcd=opt.dcd)

    # create the learnable model
    if opt.model == "DCDI-G":
        model = LearnableModel_NonLinGaussANM(opt.num_vars,
                                              opt.num_layers,
                                              opt.hid_dim,
                                              nonlin=opt.nonlin,
                                              intervention=opt.intervention,
                                              intervention_type=opt.intervention_type,
                                              intervention_knowledge=opt.intervention_knowledge,
                                              num_regimes=train_data.num_regimes)
    elif opt.model == "DCDI-DSF":
        model = DeepSigmoidalFlowModel(num_vars=opt.num_vars,
                                       cond_n_layers=opt.num_layers,
                                       cond_hid_dim=opt.hid_dim,
                                       cond_nonlin=opt.nonlin,
                                       flow_n_layers=opt.flow_num_layers,
                                       flow_hid_dim=opt.flow_hid_dim,
                                       intervention=opt.intervention,
                                       intervention_type=opt.intervention_type,
                                       intervention_knowledge=opt.intervention_knowledge,
                                       num_regimes=train_data.num_regimes)
    else:
        raise ValueError("opt.model has to be in {DCDI-G, DCDI-DSF}")

    # sanity check: which device the data manager's tensors live on
    print("train_data.adjacency.device:", train_data.adjacency.device)
    print("train_data.gt_interv.device:", train_data.gt_interv.device)

    # train until constraint is sufficiently close to being satisfied
    if opt.train:
        train(model, train_data.adjacency.detach().cpu().numpy(),
              train_data.gt_interv, train_data, test_data, opt, metrics_callback,
              plotting_callback)

    elif opt.retrain:
        initial_dag = np.load(opt.dag_for_retrain)
        model.adjacency[:, :] = torch.as_tensor(initial_dag).type(torch.Tensor)
        best_model = retrain(model, train_data, test_data, "ignored_regimes", opt, metrics_callback, plotting_callback)

    # Evaluate on ignored regimes!
    if opt.test_on_new_regimes:
        all_regimes = train_data.all_regimes

        # take all data, but ignore data on which we trained (want to test on unseen regime)
        regimes_to_ignore = np.setdiff1d(all_regimes, np.array(opt.regimes_to_ignore))
        # NOTE(review): opt.data_path / opt.i_dataset are None in this notebook's config;
        # DataManagerFile presumably needs real values here — confirm before enabling.
        new_data = DataManagerFile(opt.data_path, opt.i_dataset, 1., None, train=True,
                                   normalize=opt.normalize_data,
                                   random_seed=opt.random_seed,
                                   intervention=opt.intervention,
                                   intervention_knowledge=opt.intervention_knowledge,
                                   dcd=opt.dcd,
                                   regimes_to_ignore=regimes_to_ignore)

        with torch.no_grad():
            weights, biases, extra_params = best_model.get_parameters(mode="wbx")

            # evaluate on train
            x, masks, regimes = train_data.sample(train_data.num_samples)
            loss_train, mean_std_train = compute_loss(x, masks, regimes, best_model, weights, biases, extra_params,
                                                      intervention=True, intervention_type='structural',
                                                      intervention_knowledge="known", mean_std=True)

            # evaluate on valid
            x, masks, regimes = test_data.sample(test_data.num_samples)
            loss_test, mean_std_test = compute_loss(x, masks, regimes, best_model, weights, biases, extra_params,
                                                    intervention=True, intervention_type='structural',
                                                    intervention_knowledge="known", mean_std=True)

            # evaluate on new intervention
            x, masks, regimes = new_data.sample(new_data.num_samples)
            loss_new, mean_std_new = compute_loss(x, masks, regimes, best_model, weights, biases, extra_params,
                                                  intervention=True, intervention_type='structural',
                                                  intervention_knowledge="known", mean_std=True)

            # logging final result
            metrics_callback(stage="test_on_new_regimes", step=0,
                             metrics={"log_likelihood_train": - loss_train.item(),
                                      "mean_std_train": mean_std_train.item(),
                                      "log_likelihood_test": - loss_test.item(),
                                      "mean_std_test": mean_std_test.item(),
                                      "log_likelihood_new": - loss_new.item(),
                                      "mean_std_new": mean_std_new.item()}, throttle=False)


torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/tensor/python_tensor.cpp:451.)


train_data.adjacency.device: cuda:0
train_data.asmples.device: cuda:0


This overload of addcmul_ is deprecated:
	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:1630.)


Iteration: 0
    aug-lagrangian: 1.119273979904423
    aug-lagrangian-moving-avg: 0.011192739799044229
    aug-lagrangian-val: 1.0866025605861198
    nll: 1.0204633207028044
    nll-val: 0.9877919013845015
    nll-gap: -0.032671419318302974
    grad-norm-moving-average: 0.0006982246306931542
    delta_gamma: -inf
    omega_gamma: 0.0001
    delta_mu: inf
    omega_mu: 0.9
    constraint_violation: 0.2803709655805758
    acyclicity_violation: 0.2803709655805758
    mu: 1e-08
    gamma: 0.0
    initial_lr: 0.001
    current_lr: 0.001
    is_acyclic: 0
    true_edges: 0.0
Iteration: 100
    aug-lagrangian: 0.9805561837946822
    aug-lagrangian-moving-avg: 0.6339985285309501
    aug-lagrangian-val: 0.9398410514190427
    nll: 0.8818083912561688
    nll-val: 0.8410932588805293
    nll-gap: -0.04071513237563951
    grad-norm-moving-average: 0.045579947221820065
    delta_gamma: -inf
    omega_gamma: 0.0001
    delta_mu: inf
    omega_mu: 0.9
    constraint_violation: 0.24864810465723017
    

In [None]:
if run_DCDI:
    # Adjacency matrix held by the DCDI model, copied to the CPU as a numpy array.
    # NOTE(review): this is model.adjacency (the matrix retrain overwrites with a DAG),
    # not necessarily the learned weighted adjacency — confirm which one should be shown.
    DCDI_matrix = model.adjacency.detach().cpu().numpy()
    print(np.shape(DCDI_matrix))
    print(DCDI_matrix)

    fig, ax = plt.subplots()
    fig1 = ax.matshow(DCDI_matrix)  # the AxesImage returned by matshow
    fig.colorbar(fig1, ax=ax)
    ax.set_title("DCDI: Adjacency matrix")
    # plt.plot() with no arguments draws nothing; render the figure explicitly instead
    plt.show()