In [1]:
# %reload_ext autoreload
# %autoreload 2
# %aimport czbenchmarks.datasets.utils_single_cell
import czbenchmarks.datasets.utils_single_cell as scp

import os
import scanpy as sc
import pandas as pd
import anndata as ad
import numpy as np
from pandas.testing import assert_frame_equal
import matplotlib.pyplot as plt

In [2]:
deg_name = 'wilcoxon'

## Load Data

In [3]:
filtered_data_path = f"{os.environ['HOME']}/.cz-benchmarks/datasets/replogle_k562_essential_perturbpredict_de_results_control_cells.h5ad"
# filtered_data_path = "/data2/czbenchmarks/replogle2022/K562_essential_raw_singlecell_01.h5ad"
adata_filtered = ad.read_h5ad(filtered_data_path, backed=None)

In [4]:
de_results = pd.DataFrame(adata_filtered.uns[f"de_results_{deg_name}"])
# de_results['condition'] = de_results['condition'].cat.remove_unused_categories()
de_results['condition'] = de_results['condition'].astype(str)

control_cells_ids = adata_filtered.uns["control_cells_ids"]
control_cells_ids.pop("non-targeting", None)

array([], dtype=float64)

In [5]:
cr4_data_path = (
    "/data2/czbenchmarks/replogle2022/raw_h5ad_from_cr4/K562_essential_mtx.h5ad"
)
adata_cr4 = ad.read_h5ad(cr4_data_path, backed=None)
adata_cr4.var.index.name = "gene_id"
adata_cr4.var.rename(columns={"gene_name": "gene"}, inplace=True)
adata_cr4.obs.rename(columns={"gene_id": "condition"}, inplace=True)

## Compare Keys and Shapes

In [6]:
adata_cr4.shape, adata_filtered.shape, de_results.shape

((310385, 32628), (310385, 8563), (20793007, 7))

In [7]:
# The only missing condition is non-targeting (i.e. the control)
print(deg_name)
print(
    set(adata_filtered.obs.condition.unique())
    - set(de_results.condition.unique())
)
print(
    set(de_results.condition.unique())
    - set(adata_filtered.obs.condition.unique())
)

wilcoxon
{'non-targeting'}
set()


## Run DGE Analysis

In [None]:
print(deg_name)
# num_conditions = 10
# condition_list = np.asarray(list(control_cells_ids.keys()))
# condition_list = np.random.choice(condition_list, size=num_conditions, replace=False)
# new_control_cells_ids = {k: control_cells_ids[k] for k in condition_list}
new_control_cells_ids = control_cells_ids 
min_pert_cells=1 # 1, 50

results, conditions_filtered = scp.run_multicondition_dge_analysis(
    adata=adata_cr4,
    condition_key="condition",
    control_name="non-targeting",
    control_cells_ids=new_control_cells_ids,
    deg_test_name=deg_name,
    filter_min_cells=10,  # 0, 10
    filter_min_genes=10,  # 0, 1000
    min_pert_cells=min_pert_cells,  # 1, 50
    remove_avg_zeros=False,
    return_merged_adata=False,
)

results.to_parquet(f"/data2/compare_de_results_{deg_name}_min_pert_cells_{min_pert_cells}.arrow")



wilcoxon


Processing de conditions:   1%|▌                                                                 | 19/2057 [00:15<08:20,  4.07item/s, Completed 19/2057]

In [None]:
len(conditions_filtered), conditions_filtered

## Or load from disk

In [None]:
# results = pd.read_parquet(f"/data2/compare_de_results_{deg_name}.arrow")

In [None]:
print(deg_name, results.shape, de_results.shape)

In [None]:
32628 * 2057

## Define Intersection / Unique Sets

In [None]:
# Create a column with condition / gene_id to filter the data for similarity
def zip_names(df):
    condition = df['condition'].astype(str).values
    gene_id = df['gene_id'].astype(str).values
    mapper = list(map(lambda x: '_'.join([x[0], x[1]]), zip(condition, gene_id)))
    return mapper

results['mapper'] = zip_names(results)
de_results['mapper'] = zip_names(de_results)

print(results.shape, de_results.shape)
print("condition", results['condition'].nunique(), de_results['condition'].nunique())
print("gene_id", results['gene_id'].nunique(), de_results['gene_id'].nunique())
print(len(results['mapper']), len(de_results['mapper']))

In [None]:
# All generated condition/gene pairs are a part of the previous generated condition/gene pairs
print(set(results['mapper']).issubset(set(de_results['mapper'])))
print(set(results['condition']).issubset(set(de_results['condition'])))

In [None]:
unique_condition_gene_pairs = list(set(de_results['mapper']) - set(results['mapper']))
shared_condition_gene_pairs = list(set(de_results['mapper']).intersection(set(results['mapper'])))

unique_conditions = list(set(de_results['condition']) - set(results['condition']))
shared_conditions = list(set(de_results['condition']).intersection(set(results['condition'])))

print(len(unique_condition_gene_pairs), len(shared_condition_gene_pairs))
print(len(unique_conditions), len(shared_conditions))

In [None]:
import json
with open(f"/data2/unique_conditions_{deg_name}.json", 'w') as fh:
    json.dump(unique_conditions, fh)

In [None]:
unique_condition_gene_df = pd.DataFrame(map(lambda x: x.split('_'), unique_condition_gene_pairs))
if unique_condition_gene_df.shape[1] > 0:
    unique_condition_gene_df.columns = ['condition', 'gene_id']

print(len(unique_condition_gene_df))
if len(unique_condition_gene_df) > 0:
    unique_condition_gene_df.head()
else:
    unique_condition_gene_df = None

In [None]:
common_columns = sorted(set(de_results.columns).intersection(results.columns))
de_columns = [c for c in de_results.columns if c not in common_columns]
r_columns  = [c for c in results.columns if c not in common_columns]

In [None]:
s_de_results = de_results[common_columns].set_index('mapper').loc[shared_condition_gene_pairs]
s_results = results[common_columns].set_index('mapper').loc[shared_condition_gene_pairs]

## Examine Differences

In [None]:
de_results_unique_counts = (
    de_results
        .set_index('condition')
        .loc[unique_conditions]
        .groupby(level=0)
        .count()
        .describe()
)
de_results_unique_counts

## Compare Common Conditions / Gene IDs

In [None]:
  with pd.option_context('display.width', 10000, 'display.max_columns', 100):
    print(s_de_results.head())
    print(s_results.head())

In [None]:
s_de_results_counts = (
    s_de_results
        .set_index('condition')
        ['gene_id']
        .groupby(level=0)
        .count()
)

s_results_counts = (
    s_results
        .set_index('condition')
        ['gene_id']
        .groupby(level=0)
        .count()
)

print(s_de_results_counts.head())
print(s_results_counts.head())

In [None]:
print(assert_frame_equal(
    s_de_results_counts.to_frame(),
    s_results_counts.to_frame(),
))

In [None]:
assert_frame_equal(
    s_de_results,
    s_results,
    check_dtype=False,
    check_index_type=False,
    check_categorical=False,
    atol=1.5e-2
)

In [None]:
diff_columns = ['logfoldchange', 'pval', 'pval_adj', 'score']
diff_results = (
    s_de_results[diff_columns] -
    s_results[diff_columns]
)
diff_results.head()

In [None]:
diff_results_describe = diff_results.describe()
diff_results_describe

In [None]:
axes = diff_results.head(n=1000).plot(kind='hist', subplots=True, layout=(2,2)) # , bins=100, sharex=False, 
# for pos,ax in enumerate(axes.flatten()):
#     ax.set_xlim(diff_results_describe.loc['min'].iloc[pos], diff_results_describe.loc['max'].iloc[pos])
#     ax.set(title=diff_results.columns[pos])

## Misc Checks for Other Files

In [None]:
# Verify Target Conditions to Save
import json

with open("../target_conditions_to_save.json", "r") as f:
    target_conditions_to_save = json.load(f)

condition_list = set(map(lambda x: x.split("_")[1], target_conditions_to_save.keys()))

target_conditions_to_save_new = {}
for key, value in target_conditions_to_save.items():
    condition = key.split("_")[1]
    if condition not in target_conditions_to_save_new:
        target_conditions_to_save_new[condition] = value
    else:
        assert target_conditions_to_save_new[condition] == value


In [None]:
# Validate that existing control cells ids exist in data
skipped_conditions = []
# Convert index to set for O(1) lookup instead of O(n) for each check
index_set = set(adata_filtered.obs.index.values)
for key, values in control_cells_ids.items():
    if key not in adata_filtered.obs.condition:
        skipped_conditions.append(key)
    else:
        assert all(x in index_set for x in values), f"{values} not in {adata_filtered.obs.index.values}"

print(f"Skipped conditions: {skipped_conditions}")
