In [21]:
import os
import sys

import numpy as np
import pandas as pd
import torch
import wandb

In [5]:
# Load the serialized K562 bundle; later cells read its "expression_matrix",
# "interventions" and "gene_names" entries.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
k562_data = torch.load(f="/home/justinhong/data/k562-100.pt")

In [6]:
# Echo the loaded object to inspect its contents (a dict holding the
# expression matrix and the per-row intervention labels, among other entries).
k562_data

{'expression_matrix': array([[0.        , 4.116087  , 1.6989243 , ..., 2.0414922 , 3.9585767 ,
         1.1737294 ],
        [0.        , 4.264735  , 0.7240735 , ..., 0.7240735 , 4.1708293 ,
         0.        ],
        [0.53746116, 4.0835495 , 1.1426188 , ..., 0.885135  , 3.825033  ,
         0.53746116],
        ...,
        [0.        , 3.445676  , 1.7846454 , ..., 1.7846454 , 3.5217388 ,
         0.4822307 ],
        [0.        , 3.8232505 , 1.7644103 , ..., 0.        , 4.078512  ,
         0.        ],
        [0.        , 3.347689  , 1.6845286 , ..., 0.        , 3.7015288 ,
         0.7407354 ]], dtype=float32),
 'interventions': ['ENSG00000117360',
  'ENSG00000164134',
  'ENSG00000135775',
  'excluded',
  'ENSG00000180992',
  'ENSG00000058729',
  'ENSG00000115884',
  'ENSG00000103549',
  'ENSG00000072849',
  'ENSG00000198755',
  'ENSG00000148297',
  'ENSG00000154473',
  'ENSG00000116560',
  'non-targeting',
  'ENSG00000085840',
  'ENSG00000102978',
  'ENSG00000087365',
  'ENSG0

In [7]:
# (rows, columns) of the expression matrix; the columns are labelled with
# `gene_names` when the DataFrame is built further down.
np.shape(k562_data["expression_matrix"])

(130200, 622)

In [15]:
# Number of distinct intervention labels (includes the special
# "excluded" / "non-targeting" entries visible in the dump above).
np.unique(k562_data["interventions"]).size

641

In [18]:
# Collapse every label that is not a usable gene perturbation into "obs"
# (observational): the "non-targeting" and "excluded" markers, plus any
# intervention whose target is not one of the columns in `gene_names`.
gene_names = k562_data["gene_names"]
orig_interventions = np.array(k562_data["interventions"], dtype="object")

control_markers = {"non-targeting", "excluded"}
measured_genes = set(gene_names)

perturbation_label = np.array(
    [
        "obs" if (label in control_markers or label not in measured_genes) else label
        for label in orig_interventions
    ],
    dtype="object",
)
        

In [26]:
# Build the frame from the float expression matrix first and attach the labels
# afterwards. Stacking the object-dtype label column onto the matrix with
# np.hstack would coerce every expression value to a Python object, leaving all
# gene columns with a non-numeric dtype and inflating memory use.
X_df = pd.DataFrame(k562_data["expression_matrix"], columns=list(gene_names))
X_df["perturbation_label"] = perturbation_label
X_df

Unnamed: 0,ENSG00000116809,ENSG00000142676,ENSG00000188529,ENSG00000133226,ENSG00000090273,ENSG00000142784,ENSG00000117748,ENSG00000126698,ENSG00000180198,ENSG00000116560,...,ENSG00000101901,ENSG00000125352,ENSG00000101882,ENSG00000125676,ENSG00000134597,ENSG00000147274,ENSG00000102030,ENSG00000147403,ENSG00000071553,perturbation_label
0,0.0,4.116087,1.698924,2.041492,0.750007,0.0,1.173729,1.173729,0.0,2.041492,...,0.0,0.0,0.750007,0.750007,0.0,1.470416,2.041492,3.958577,1.173729,ENSG00000117360
1,0.0,4.264735,0.724074,1.139639,1.432333,0.724074,0.0,0.724074,0.0,1.842785,...,0.0,0.0,0.0,1.432333,0.0,1.432333,0.724074,4.170829,0.0,ENSG00000164134
2,0.537461,4.083549,1.142619,0.885135,2.002142,0.0,0.0,1.347196,0.885135,2.327426,...,0.537461,0.0,0.885135,1.142619,0.0,1.516945,0.885135,3.825033,0.537461,ENSG00000135775
3,0.581381,4.184453,0.581381,0.0,0.946631,0.0,0.581381,1.213582,0.581381,1.597878,...,0.0,0.0,0.0,0.0,0.0,0.581381,0.0,2.488159,0.0,obs
4,0.0,3.70225,1.201707,1.731935,1.731935,0.0,0.771447,0.771447,0.0,1.731935,...,0.771447,0.0,0.771447,1.731935,0.771447,2.076608,0.0,3.861062,0.0,ENSG00000180992
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
130195,0.0,3.652522,1.166304,2.03213,1.875682,0.0,0.744341,1.690135,0.744341,2.286499,...,0.0,0.0,0.0,1.690135,1.166304,1.462135,1.166304,3.904441,0.0,ENSG00000149636
130196,0.0,3.758051,1.870334,0.863993,1.870334,0.0,1.118063,1.118063,0.522541,2.223076,...,0.0,0.0,0.0,0.522541,0.863993,1.758574,1.320483,3.906565,0.0,ENSG00000164151
130197,0.0,3.445676,1.784645,1.410601,2.331683,0.0,0.0,1.674811,1.050489,1.883602,...,0.0,0.0,0.806193,1.410601,0.482231,1.551406,1.784645,3.521739,0.482231,ENSG00000008988
130198,0.0,3.823251,1.76441,2.247783,1.229367,0.0,0.0,1.532254,0.0,2.368022,...,0.792781,0.792781,0.792781,1.229367,0.0,0.792781,0.0,4.078512,0.0,ENSG00000029364


In [27]:
# Persist the cleaned frame; the row index is written as the first CSV column.
X_df.to_csv(path_or_buf="/home/justinhong/data/cleaned_k562.csv")

In [None]:
# Reload the cleaned frame. The file was saved with its row index as the first
# column, so read that column back as the index; otherwise pandas adds an
# "Unnamed: 0" column that would be handed downstream as if it were a gene.
X_df = pd.read_csv("/home/justinhong/data/cleaned_k562.csv", index_col=0)

## Run SDCI

In [29]:
# Put the parent directory on the import path before importing the project
# modules below. "../" is resolved against the kernel's current working
# directory, so this presumably expects the notebook to run from its own
# folder — TODO confirm.
sys.path.append("../")
from models import SDCI
from train_utils import create_intervention_dataset

In [30]:
# Wrap the cleaned frame (gene columns + "perturbation_label") in the dataset
# object that the model's train() call consumes.
# NOTE(review): the meaning of regime_format=False is defined in train_utils — confirm there.
dataset = create_intervention_dataset(
    X_df,
    regime_format=False,
)

In [None]:
# Fit SDCI on the intervention dataset, logging the run to the
# "cb-perturb-seq" Weights & Biases project.
model = SDCI()
model.train(dataset, log_wandb=True, wandb_project="cb-perturb-seq")
print(f"Ran in {model._train_runtime_in_sec // 60} minutes.")

# Close the W&B run; `wandb` must be imported in the notebook's import cell.
wandb.finish()

# Write the adjacency matrix returned by the model to CSV. The NumPy writer is
# `np.savetxt`; `np.savetext` does not exist and would raise AttributeError
# only after the whole training run had finished.
adj_mtx = model.get_adjacency_matrix()
np.savetxt("/home/justinhong/results/cb_perturb_seq_adj_mtx.csv", adj_mtx, delimiter=",")