In [16]:
import numpy as np
import pandas as pd
import anndata as ad
from tqdm import tqdm
import torch
import os

In [4]:
expression = ad.read_h5ad("/hpc/projects/group.califano/GLM/data/_replogle/data/replogle_200k_clean/K562/cells.h5ad", backed="r")
# expression = ad.read_h5ad("/hpc/archives/group.califano/replogle.h5ad", backed="r")
expression

AnnData object with n_obs × n_vars = 199954 × 7948 backed at '/hpc/projects/group.califano/GLM/data/_replogle/data/replogle_200k_clean/K562/cells.h5ad'
    obs: 'gem_group', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'sgID_AB', 'mitopercent', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_counts', 'sample_id', 'cluster'
    var: 'gene_name', 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'cluster', 'log1p', 'neighbors', 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [5]:
perturbation_var = "gene_id"
# Below, any cells where perturbation_var is 'nan' - or the perturbed gene is not in the column genes - are discarded
gene_ids = expression.obs[perturbation_var] # Get the perturbed gene ids
control_mask = (gene_ids == "non-targeting") # Mask accounting for cells with non-targeting intervention
perturbed_mask = gene_ids.isin(expression.var.index).to_numpy() # Mask accounting for cells with perturbed genes that are included in our dataset

control_cells = expression[control_mask] # Get the control cells from the perturbation dataset
perturbed_cells = expression[perturbed_mask] # Get the perturbed cells from the perturbation dataset

print("Number of control cells:", control_cells.shape[0])
print("Number of perturbed cells:", perturbed_cells.shape[0])

Number of control cells: 7454
Number of perturbed cells: 147797


#### Get the average control gene expression for K562

In [6]:
ctrl_array = control_cells.to_memory().X.toarray()
ctrl_array.shape

(7454, 7948)

In [7]:
ctrl_mean = np.mean(ctrl_array, axis=0)
ctrl_mean.shape

(7948,)

#### Get the average perturbed gene expression for K562

In [8]:
# Convert to pandas DataFrames
perturbation_ENSGs = perturbed_cells.obs[perturbation_var].reset_index(drop=True) # Series object of perturbation for each cell (after filtering cells)
perturbation_ENSGs.unique().shape

(7645,)

In [35]:
mean_perturbed = {}
for prt in tqdm(perturbation_ENSGs.unique()):
    mask = (expression.obs["gene_id"] == prt)
    prt_array = expression[mask, :].to_memory().X.toarray()
    prt_mean = np.mean(prt_array, axis=0)
    mean_perturbed[prt] = prt_mean

  0%|          | 16/7645 [00:39<5:14:46,  2.48s/it]


KeyboardInterrupt: 

In [36]:
mean_perturbed

{'ENSG00000109861': array([3.6378477 , 0.5262999 , 0.4105049 , ..., 9.258563  , 0.15327616,
        1.399147  ], dtype=float32),
 'ENSG00000273559': array([3.5479207 , 0.6434493 , 0.5597913 , ..., 9.192645  , 0.23590085,
        1.6675168 ], dtype=float32),
 'ENSG00000178104': array([3.5178235 , 0.15732586, 0.8247604 , ..., 9.173493  , 0.72716874,
        1.3446441 ], dtype=float32),
 'ENSG00000074755': array([4.5588818 , 0.29358825, 1.6686676 , ..., 9.12903   , 1.0224719 ,
        0.89437866], dtype=float32),
 'ENSG00000143553': array([4.795863  , 0.        , 0.62443715, ..., 9.070008  , 1.1783478 ,
        0.54728687], dtype=float32),
 'ENSG00000004534': array([3.897406  , 0.7445368 , 0.23900257, ..., 9.24987   , 0.        ,
        1.0193276 ], dtype=float32),
 'ENSG00000005100': array([3.5411367 , 0.58324254, 0.7075229 , ..., 9.364246  , 1.1072527 ,
        1.2702887 ], dtype=float32),
 'ENSG00000129250': array([3.8026762, 0.6870104, 0.9667298, ..., 9.254552 , 1.3565712,
        1.

#### Iterate and find the lowest L2 norm

In [37]:
effect = {}
for prt in tqdm(mean_perturbed.keys()):
    effect_size = np.linalg.norm(mean_perturbed[prt] - ctrl_mean)
    effect[prt] = effect_size

100%|██████████| 16/16 [00:00<00:00, 70344.72it/s]


In [33]:
qq = mean_perturbed

In [39]:
sum(list(effect.values()))

571.4550552368164

In [40]:
effect_df = pd.DataFrame(effect)
effect_df

ValueError: If using all scalar values, you must pass an index

In [None]:
import json

my_dict = {"a": 1, "b": 2, "c": 3}

with open("output.json", "w") as f:
    json.dump(my_dict, f, indent=4)
