Filtered Data from Authors: K562_essential_raw_singlecell_01.h5ad
Unfiltered Data from Jasleen: K562_essential_mtx.h5ad



In [3]:
# Shell escape: make sure pyarrow is installed in the active environment.
# NOTE(review): presumably needed as the parquet engine for `results.to_parquet(...)` further down — confirm.
! uv pip install pyarrow

[2mUsing Python 3.12.7 environment at: /home/mgill/code/cz-benchmarks/.venv[0m
[2mAudited [1m1 package[0m [2min 6ms[0m[0m


In [1]:
# %reload_ext autoreload
# %autoreload 2
# %aimport czbenchmarks.datasets.utils_single_cell
import czbenchmarks.datasets.utils_single_cell as scp

import os
import scanpy as sc
import pandas as pd
import anndata as ad
import numpy as np
from pandas.testing import assert_frame_equal

In [2]:
# Name of the differential-expression test. It is passed as `deg_test_name` to the
# DGE run and embedded in the results file name.
deg_name = 'wilcoxon'

## Load Data

In [24]:
# Input AnnData file, located under $HOME/.cz-benchmarks/datasets.
filtered_data_path = f"{os.environ['HOME']}/.cz-benchmarks/datasets/replogle_k562_essential_perturbpredict_de_results_control_cells.h5ad"
# filtered_data_path = "/data2/czbenchmarks/replogle2022/K562_essential_raw_singlecell_01.h5ad"
# backed=None reads the file into memory instead of opening it in backed mode.
adata_filtered = ad.read_h5ad(filtered_data_path, backed=None)

# Mapping stored in `uns`; later cells iterate it as key -> collection of cell ids.
control_cells_ids = adata_filtered.uns["control_cells_ids"]
# Remove the "non-targeting" entry in place (raises KeyError if it is absent).
# Being the last expression of the cell, the popped value is echoed as the cell output.
# NOTE(review): this presumably also removes the entry from `adata_filtered.uns`,
# since no copy is made — confirm that is intended.
control_cells_ids.pop("non-targeting")

array([], dtype=float64)

In [4]:
# de_results = pd.DataFrame(adata_filtered.uns["de_results_wilcoxon"])

## Run DGE Analysis

In [6]:
# Echo which DE test this run uses.
print(deg_name)

# Arguments for the multi-condition DGE run. The trailing comments keep the
# alternative values noted by the author.
dge_settings = dict(
    adata=adata_filtered,
    condition_key="condition",
    control_name="non-targeting",
    control_cells_ids=control_cells_ids,
    deg_test_name=deg_name,
    filter_min_cells=0,  # 0, 10
    filter_min_genes=0,  # 0, 1000
    min_pert_cells=1,  # 1, 50
    remove_avg_zeros=False,
    return_merged_adata=False,
)
results = scp.run_multicondition_dge_analysis(**dge_settings)

# Persist the results table to disk in parquet format.
results_path = f"/data2/compute_de_results_{deg_name}.arrow"
results.to_parquet(results_path)

wilcoxon


Processing de conditions: 100%|██████████████████████████████████████████████████████████████| 2057/2057 [05:35<00:00,  6.14item/s, Completed 2057/2057]


In [7]:
# Preview the first rows of the DGE results table.
results.head()

Unnamed: 0,condition,gene_id,score,logfoldchange,pval,pval_adj
0,ENSG00000001497,ENSG00000112306,4.758876,0.262604,2e-06,0.002778
1,ENSG00000001497,ENSG00000265972,4.707202,1.077903,3e-06,0.003072
2,ENSG00000001497,ENSG00000187837,4.311571,0.915918,1.6e-05,0.012619
3,ENSG00000001497,ENSG00000196262,4.114563,0.257904,3.9e-05,0.025552
4,ENSG00000001497,ENSG00000084207,4.041896,0.344575,5.3e-05,0.028376


## Or Load from Disk

In [8]:
# results = pd.read_parquet(f"/data2/compute_de_results_{deg_name}.arrow")

In [26]:
# Unique conditions present in the DGE results. Sorted so the list order is
# reproducible across kernel restarts; the iteration order of a plain set of
# strings can differ between Python processes.
conditions = sorted(set(results["condition"]))

In [27]:
# Show the shape of `results` next to n_vars * number of conditions so the row
# count can be compared by eye.
# NOTE(review): one row per (condition, gene) pair is presumably expected because
# the DGE run above disables its filters — confirm.
results.shape, adata_filtered.n_vars, len(conditions), adata_filtered.n_vars * len(conditions)

((17614091, 6), 8563, 2057, 17614091)

In [28]:
# Hand-written product of the condition count and gene count shown in the output
# above. The literals are hard-coded and will not follow a change in the data.
2057 * 8563

17614091

In [29]:
# Number of result rows (gene ids) reported for each condition.
results_counts = results.groupby("condition")[["gene_id"]].count()
results_counts.head()

Unnamed: 0_level_0,gene_id
condition,Unnamed: 1_level_1
ENSG00000001497,8563
ENSG00000003509,8563
ENSG00000004779,8563
ENSG00000004897,8563
ENSG00000005007,8563


In [32]:
# True when every condition has exactly `n_vars` result rows.
results_counts.eq(adata_filtered.n_vars).all()

gene_id    True
dtype: bool

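**FIXME:** update this section to compare against the existing validation file. The cells below still reference `de_keys` and `de_results`, which are not defined anywhere in this notebook.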
## Similarity Comparison -- Create Condition - Gene Id Mapper

In [9]:
# Print, per DE test, the table shapes and the number of distinct conditions and
# gene ids in the recomputed results versus the stored reference results.
# NOTE(review): `de_keys` and `de_results` are not defined in any cell visible in
# this notebook (the only `de_results` assignment is commented out), and `results`
# is indexed by test name here although the DGE cell binds it to a single
# DataFrame. This cell will therefore not run on a fresh kernel as written.
for deg_test_name in de_keys:
    print(deg_test_name, results[deg_test_name].shape, de_results[deg_test_name].shape)
    print("condition", results[deg_test_name]['condition'].nunique(), de_results[deg_test_name]['condition'].nunique())
    print("gene_id", results[deg_test_name]['gene_id'].nunique(), de_results[deg_test_name]['gene_id'].nunique())

wilcoxon (19223723, 6) (20793007, 7)
condition 1832 2057
gene_id 15730 15730


In [10]:
# Create a column with condition / gene_id to filter the data for similarity
def zip_names(df):
    """Build one "<condition>_<gene_id>" key per row of ``df``.

    Args:
        df: DataFrame with ``condition`` and ``gene_id`` columns; both are
            converted to strings before joining.

    Returns:
        A list of strings, one per row, in row order.
    """
    conditions = df['condition'].astype(str).values
    gene_ids = df['gene_id'].astype(str).values
    return ['_'.join(pair) for pair in zip(conditions, gene_ids)]

# Attach the combined condition/gene key as a 'mapper' column to both the
# recomputed and the reference table of every DE test.
for deg_test_name in de_keys:
    for table in (results[deg_test_name], de_results[deg_test_name]):
        table['mapper'] = zip_names(table)

In [11]:
# For each DE test, collect the condition/gene keys present in both tables.
shared_condition_genes = {}
for deg_test_name in de_keys:
    recomputed_keys = set(results[deg_test_name]['mapper'])
    reference_keys = set(de_results[deg_test_name]['mapper'])
    shared_condition_genes[deg_test_name] = list(recomputed_keys.intersection(reference_keys))

In [12]:
# Show the container type and number of shared condition/gene keys.
# `deg_test_name` here is whatever value the preceding `for` loop left behind,
# i.e. the last entry of `de_keys`.
type(shared_condition_genes[deg_test_name]), len(shared_condition_genes[deg_test_name])

(list, 19223723)

In [13]:
def _rows_for_keys(table, keys):
    """Return the rows of ``table`` whose 'mapper' value is in ``keys``, in ``keys`` order, with a fresh index."""
    return table.set_index('mapper').loc[keys].reset_index(drop=True)


# Restrict both tables of every DE test to the shared condition/gene keys so
# that their rows line up one-to-one.
s_results = {}
s_de_results = {}

for deg_test_name in de_keys:
    shared_keys = shared_condition_genes[deg_test_name]
    s_results[deg_test_name] = _rows_for_keys(results[deg_test_name], shared_keys)
    s_de_results[deg_test_name] = _rows_for_keys(de_results[deg_test_name], shared_keys)

## Compare Common Results

In [14]:
# Work out which columns the reference and recomputed tables share, then print
# a preview of each table with the shared columns first.
common_columns = {}
for deg_test_name in de_keys:
    print(deg_test_name)
    reference_table = s_de_results[deg_test_name]
    recomputed_table = s_results[deg_test_name]

    common_columns[deg_test_name] = sorted(
        set(reference_table.columns).intersection(recomputed_table.columns)
    )
    # Columns that exist in only one of the two tables.
    old_columns = [c for c in reference_table.columns if c not in common_columns[deg_test_name]]
    new_columns = [c for c in recomputed_table.columns if c not in common_columns[deg_test_name]]

    with pd.option_context('display.width', 10000, 'display.max_columns', 100):
        print(reference_table[common_columns[deg_test_name] + old_columns].head())
        print(recomputed_table[common_columns[deg_test_name] + new_columns].head())

wilcoxon
         condition          gene_id  logfoldchange      pval  pval_adj     score condition_name
0  ENSG00000115866  ENSG00000113356       0.022121  0.950154  0.999435  0.062514           DARS
1  ENSG00000155666  ENSG00000273802      -0.247788  0.560620  0.999799 -0.581920           KDM8
2  ENSG00000109606  ENSG00000175221      -1.007420  0.000532  0.004872 -3.463998          DHX15
3  ENSG00000095139  ENSG00000136811       0.087385  0.683991  1.000000  0.407024          ARCN1
4  ENSG00000169045  ENSG00000119421       0.177865  0.172091  0.999798  1.365515        HNRNPH1
         condition          gene_id  logfoldchange      pval  pval_adj     score
0  ENSG00000115866  ENSG00000113356       0.022121  0.950154  0.999435  0.062514
1  ENSG00000155666  ENSG00000273802      -0.247788  0.560620  0.999799 -0.581920
2  ENSG00000109606  ENSG00000175221      -1.007420  0.000532  0.004872 -3.463998
3  ENSG00000095139  ENSG00000136811       0.087385  0.683991  1.000000  0.407024
4  ENSG000

In [27]:
# Count result rows per condition in the reference and recomputed tables, then
# keep only the conditions that appear in both count tables.
orig_wilcoxon_counts = {}
new_wilcoxon_counts = {}

for deg_test_name in de_keys:
    reference_counts = s_de_results[deg_test_name].groupby("condition")[["gene_id"]].count()
    recomputed_counts = s_results[deg_test_name].groupby("condition")[["gene_id"]].count()

    print(deg_test_name)
    # Number of conditions in each count table before restricting to the shared ones.
    print(len(reference_counts), len(recomputed_counts))

    unique_conditions = sorted(
        set(reference_counts.index).intersection(recomputed_counts.index)
    )
    orig_wilcoxon_counts[deg_test_name] = reference_counts.loc[unique_conditions]
    new_wilcoxon_counts[deg_test_name] = recomputed_counts.loc[unique_conditions]

    print(orig_wilcoxon_counts[deg_test_name].head())
    print(new_wilcoxon_counts[deg_test_name].head())

  s_de_results[deg_test_name][["condition", "gene_id"]].groupby("condition").count()


wilcoxon
2058 1832
                 gene_id
condition               
ENSG00000001497     9889
ENSG00000004779    10746
ENSG00000004897    12173
ENSG00000005007    10546
ENSG00000005100    10502
                 gene_id
condition               
ENSG00000001497     9889
ENSG00000004779    10746
ENSG00000004897    12173
ENSG00000005007    10546
ENSG00000005100    10502


In [28]:
# Per-condition row counts must match between the reference and recomputed tables.
# Dtype, index type and categorical details are ignored in the comparison.
count_comparison_options = dict(
    check_dtype=False,
    check_index_type=False,
    check_categorical=False,
)
for deg_test_name in de_keys:
    print(deg_test_name)
    assert_frame_equal(
        orig_wilcoxon_counts[deg_test_name],
        new_wilcoxon_counts[deg_test_name],
        **count_comparison_options,
    )

wilcoxon


In [32]:
# Compare the shared columns of the reference and recomputed tables value by value,
# allowing an absolute tolerance of 1e-4 and ignoring dtype / index type / categorical
# differences.
# NOTE(review): the recorded output of this cell is an AssertionError on the `pval`
# column (0.03325 % of values differ; first difference 0.540048... vs 0.540620...,
# which exceeds the tolerance). The mismatch is unresolved in this notebook.
for deg_test_name in de_keys:
    print(deg_test_name)
    # Row counts of the two tables being compared.
    print(len(s_de_results[deg_test_name][common_columns[deg_test_name]]), len(s_results[deg_test_name][common_columns[deg_test_name]]))
    assert_frame_equal(
        s_de_results[deg_test_name][common_columns[deg_test_name]],
        s_results[deg_test_name][common_columns[deg_test_name]],
        check_dtype=False,
        check_index_type=False,
        check_categorical=False,
        atol=1e-4
    )

wilcoxon
19223723 19223723


AssertionError: DataFrame.iloc[:, 3] (column name="pval") are different

DataFrame.iloc[:, 3] (column name="pval") values are different (0.03325 %)
[index]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, ...]
[left]:  [0.950153857543032, 0.5606204865261339, 0.0005322111935425, 0.68399051781166, 0.1720911505157198, 0.9969392219482772, 0.921392773980092, 0.1580455948894672, 0.2602954984994338, 0.7368714326045523, 0.1733138552747071, 0.3151932431314119, 0.5810950290404722, 0.6713197654014221, 0.6597510146580998, 0.7870172193578943, 0.5790260246423479, 0.2404660934721755, 0.5720442708822588, 0.6902298653876919, 0.7927452855893585, 0.5999966170648937, 0.2868328644263079, 0.907324530682686, 0.7752774772057947, 0.7389080890267076, 0.081259987112111, 0.7921304483278214, 0.1373683733986759, 0.2183111049881834, 0.5751590692274258, 0.9834997589324436, 0.2080682435896055, 0.872649858035315, 0.3232766648052418, 2.2625010940509723e-05, 0.8911803538742676, 0.4378710424070492, 0.1692640974874255, 0.7381774270738973, 0.9532747017471532, 0.1184512110854366, 0.3705507162250756, 0.9057129520424524, 0.1389880609863609, 0.6029631868104641, 0.0233769616463758, 0.6326359969185638, 0.9409485023014384, 0.7042751261420694, 0.1899671425568566, 0.724440873260489, 0.3683695634332092, 0.1172403851072333, 0.114424549843245, 0.023946092028222, 0.8812813459080352, 0.2467362672672604, 0.0154900678808812, 0.9618555847602762, 0.511136946879874, 0.0799456575308412, 0.7419269482798236, 0.0245995263296245, 0.1079089420678407, 0.7075336390228358, 0.7263012089668777, 0.2260850647418633, 0.8027375258313723, 0.5706434857453426, 0.8510373106226896, 0.6585340748669761, 0.0739283022866417, 0.8641054876405486, 0.7187465460906216, 0.6804745538746109, 0.920495451373839, 0.811285365964405, 0.9956471053076132, 0.8965170394131275, 0.501626436800819, 0.4934062707425712, 0.3557015393888678, 0.8821319542880073, 0.0042301204305526, 0.6707348068982417, 0.3988366323347574, 0.4468728207108308, 0.6612682645516857, 0.8364121328558082, 0.6587495209582304, 0.7686634621428239, 0.862328532223954, 0.5364376416864287, 0.7565451517255931, 0.9680527435154468, 0.944903236436636, 0.8977021146900537, 0.4660533144060884, 0.4769168733811332, 
...]
[right]: [0.950153857543032, 0.5606204865261339, 0.0005322111935425807, 0.68399051781166, 0.17209115051571988, 0.9969392219482773, 0.921392773980092, 0.15804559488946723, 0.2602954984994338, 0.7368714326045523, 0.1733138552747071, 0.31519324313141195, 0.5810950290404722, 0.6713197654014221, 0.6597510146580998, 0.7870172193578943, 0.5790260246423479, 0.2404660934721755, 0.5720442708822588, 0.6902298653876919, 0.7927452855893585, 0.5999966170648937, 0.28683286442630795, 0.9073245306826859, 0.7752774772057947, 0.7389080890267076, 0.08125998711211105, 0.7921304483278214, 0.13736837339867594, 0.2183111049881834, 0.5751590692274258, 0.9834997589324436, 0.20806824358960552, 0.872649858035315, 0.32327666480524186, 2.2625010940509723e-05, 0.8911803538742676, 0.4378710424070492, 0.16926409748742555, 0.7381774270738973, 0.9532747017471533, 0.11845121108543669, 0.3705507162250756, 0.9057129520424525, 0.1389880609863609, 0.6029631868104641, 0.023376961646375816, 0.6326359969185638, 0.9409485023014384, 0.7042751261420694, 0.18996714255685665, 0.724440873260489, 0.36836956343320926, 0.11724038510723335, 0.11442454984324502, 0.023946092028222076, 0.8812813459080352, 0.2467362672672604, 0.015490067880881222, 0.9618555847602762, 0.511136946879874, 0.07994565753084129, 0.7419269482798236, 0.024599526329624557, 0.1079089420678407, 0.7075336390228358, 0.7263012089668777, 0.2260850647418633, 0.8027375258313723, 0.5706434857453426, 0.8510373106226896, 0.6585340748669761, 0.0739283022866417, 0.8641054876405486, 0.7187465460906216, 0.6804745538746109, 0.920495451373839, 0.811285365964405, 0.9956471053076132, 0.8965170394131275, 0.501626436800819, 0.49340627074257126, 0.35570153938886784, 0.8821319542880073, 0.00423012043055267, 0.6707348068982417, 0.3988366323347574, 0.4468728207108308, 0.6612682645516857, 0.8364121328558082, 0.6587495209582304, 0.7686634621428239, 0.862328532223954, 0.5364376416864287, 0.7565451517255931, 0.9680527435154469, 0.9449032364366359, 0.8977021146900537, 
0.46605331440608844, 0.4769168733811332, ...]
At positional index 1247, first diff: 0.540048433041013 != 0.5406201181652415

## Misc Checks for Other Files

In [None]:
# Verify the target-conditions file: entries whose keys share the same condition
# token must carry equal values.
import json

with open("../target_conditions_to_save.json", "r") as f:
    target_conditions_to_save = json.load(f)

# The condition is taken to be the second "_"-separated token of each key.
condition_list = {key.split("_")[1] for key in target_conditions_to_save}

# Collapse the mapping to one entry per condition, asserting agreement on repeats.
target_conditions_to_save_new = {}
for key, value in target_conditions_to_save.items():
    condition = key.split("_")[1]
    if condition in target_conditions_to_save_new:
        assert target_conditions_to_save_new[condition] == value
    else:
        target_conditions_to_save_new[condition] = value



In [None]:
# Validate that every stored control-cell id exists in the filtered AnnData.
# Conditions that do not occur in `adata_filtered.obs["condition"]` are skipped
# and reported at the end.
skipped_conditions = []
# Sets give O(1) membership checks instead of an O(n) scan per lookup.
index_set = set(adata_filtered.obs.index.values)
# `key in series` tests the Series *index* (here: cell ids), not its values, so
# the condition values are materialised explicitly before the membership test.
observed_conditions = set(adata_filtered.obs["condition"])
for key, values in control_cells_ids.items():
    if key not in observed_conditions:
        skipped_conditions.append(key)
        continue
    # Collect the offending ids so a failure names them instead of dumping whole arrays.
    missing_ids = [x for x in values if x not in index_set]
    assert not missing_ids, (
        f"{len(missing_ids)} control cell ids for condition {key!r} are not in "
        f"adata_filtered.obs.index (first few: {missing_ids[:5]})"
    )

print(f"Skipped conditions: {skipped_conditions}")
