In [1]:
from causcell import CausCell # third-party package: causcell must be installed in this environment
import warnings # Python standard library; no separate installation is required
warnings.filterwarnings("ignore") # silence all warnings emitted from this point on

# model training part

In [2]:
# Instantiate the CausCell model object used by the cells below.
# NOTE(review): save_and_sample_every=10 presumably sets how many training
# steps elapse between saved checkpoints/samples -- confirm in the causcell docs.
model = CausCell(
    save_and_sample_every=10,
)

# Concept metadata for the Merfish_Brain dataset kept in the Data directory:
# the concept names, the value count of each concept, and the causal
# structure between concepts.
concept_list = ["Age", "Domain", "Celltype"]

# One count per concept -- presumably aligned by position with concept_list.
concept_counts = [
    3,
    8,
    10,
]

# Causal structure between concepts as a 0/1 matrix. Only row 2 has non-zero
# entries (columns 0 and 1).
# NOTE(review): the matrix is 4x4 although only three concepts are named, and
# the row/column direction convention is not visible here -- confirm both
# against the causcell documentation.
concept_cdag = [
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
]

In [3]:
# Directory that receives the model-training output.
results_folder = "../Output"

# Convert the raw training .h5ad file into the format CausCell trains on.
# NOTE(review): the training step reads
# "../Data/transformed_Merfish_Brain_training_data.h5ad", so this call
# presumably writes that file into save_pwd -- confirm.
transformed_train_data = model.data_transformation(
    data_pwd="../Data/Merfish_Brain_training_data.h5ad",
    save_pwd="../Data",
    concept_list=concept_list,
    log_norm=True,
)

In [4]:
# Train the model on the transformed training data.
# This demo uses 1,000 training steps; we recommend 100,000 training steps.
# model_save_pwd points at results_folder instead of repeating the "../Output"
# literal, so the directory written here cannot drift from the directory the
# model-loading step reads from.
model.train(
    training_data_pwd="../Data/transformed_Merfish_Brain_training_data.h5ad",
    model_save_pwd=results_folder,
    concept_list=concept_list,
    concept_counts=concept_counts,
    concept_cdag=concept_cdag,
    training_num_steps=1000,
    train_log=False,
)

training completed


# model loading part

In [5]:
# Directory that holds the results of the training run.
results_folder = "../Output"

# Restore the model parameters saved by the previous training run.
# NOTE(review): milestone=100 presumably selects the saved checkpoint index
# (1000 training steps / save_and_sample_every=10), and
# trained_profile_size=374 presumably is the number of genes per expression
# profile -- confirm both against the causcell documentation.
model.load_trained(
    concept_list=concept_list,
    concept_counts=concept_counts,
    concept_cdag=concept_cdag,
    results_folder=results_folder,
    trained_profile_size=374,
    milestone=100,
)

# testing data transformation

In [6]:
# Convert the raw test .h5ad file into the CausCell input format.
# NOTE(review): unlike the training-data transformation, log_norm is not
# passed here, so the library default applies -- confirm that training and
# test data end up with the same preprocessing.
transformed_test_data = model.data_transformation(
    data_pwd="../Data/Merfish_Brain_testing_data.h5ad",
    save_pwd="../Data",
    concept_list=concept_list,
)

# Path of the transformed test dataset used by the following steps.
testing_data_pwd = "../Data/transformed_Merfish_Brain_testing_data.h5ad"

# obtain the concept representations and the reconstructed cells for the test dataset

In [7]:
# Obtain the concept representations of all cells in the test dataset.
# saved_pwd uses results_folder rather than a repeated "../Output" literal so
# every step refers to the same output directory.
concept_embs = model.disentanglement(
    testing_data_pwd=testing_data_pwd,
    saved_pwd=results_folder,
    concept_list=concept_list,
    concept_counts=concept_counts,
    concept_cdag=concept_cdag,
)

# Obtain the reconstructed gene expression profiles of all cells in the test
# dataset.
generated_cells = model.get_generated_cells(
    testing_data_pwd=testing_data_pwd,
    saved_pwd=results_folder,
    concept_list=concept_list,
    concept_counts=concept_counts,
    concept_cdag=concept_cdag,
)

# Show both results in the cell output.
print(concept_embs)
print(generated_cells)

sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 251.64it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 193.01it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 186.82it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 198.15it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 269.48it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 266.14it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 193.33it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 185.64it/s]
sampling loop time step: 100%|██

[[[-0.03774261  0.09436035 -0.28049928 ... -2.2432618  -2.1854491
    0.13760376]
  [-0.6255661   0.6060425  -0.59526366 ...  0.22362633 -0.9994141
    0.38067627]
  [-0.71817017  0.46741027 -0.47032624 ... -1.0249268  -0.66607666
    0.74243164]
  ...
  [-1.2075195  -1.0665039  -2.7210937  ... -2.1920898  -3.6007812
    1.7867187 ]
  [-1.4721069  -0.84521484 -1.5464844  ... -1.2559814  -1.8638672
    0.73168945]
  [-1.2072265   0.04457779 -2.3513672  ... -1.8014648  -4.2328124
    1.0612793 ]]

 [[-0.14243469  0.6914886   0.49133912 ... -0.19420166  0.27676392
    1.2553711 ]
  [-0.55852664  1.7215332   0.2387909  ...  1.3014648   0.31305543
    0.6621338 ]
  [-0.79126585  1.3989258   0.36828613 ...  0.22573853 -0.2587158
    0.9105713 ]
  ...
  [-1.3381836   0.3678711   0.04198761 ... -0.5686798  -0.16378784
    1.2751465 ]
  [-0.81572264  2.1604493   0.65339357 ... -0.842807    0.33544922
    0.5028931 ]
  [-1.2440917   1.3043334  -0.73083496 ...  0.12379456  1.0036377
    1.4610351

# counterfactual generation part

In [8]:
# Counterfactual intervention targets, one dict per intervention.
# NOTE(review): 0 and 2 are presumably encoded category indices of "Age"
# (which is declared with 3 values) -- confirm the encoding.
multi_target_list = [
    {"target_factor": "Age", "ref_factor_value": 0, "tgt_factor_value": 2},
]

# Generate counterfactual cells from the intervened concepts.
# save_pwd uses results_folder rather than a repeated '../Output' literal so
# every step refers to the same output directory.
# NOTE(review): data_pwd is the untransformed test file here, whereas the
# disentanglement/reconstruction step uses the transformed file in
# testing_data_pwd -- confirm this is what counterfactual_generation expects.
counterfactual_generated_cells = model.counterfactual_generation(
    data_pwd="../Data/Merfish_Brain_testing_data.h5ad",
    save_pwd=results_folder,
    concept_list=concept_list,
    concept_counts=concept_counts,
    concept_cdag=concept_cdag,
    multi_target_list=multi_target_list,
    file_name="Counterfactual_generated_cells",
    batch_size=1000,
)
print(counterfactual_generated_cells)

sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 264.35it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 243.75it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 214.90it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 246.03it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 244.36it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 212.71it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 210.95it/s]
sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 212.57it/s]
sampling loop time step: 100%|██

AnnData object with n_obs × n_vars = 25002 × 374
    obs: 'Age', 'Batch', 'Domain', 'Celltype', 'Type'
