In [1]:
from pathlib import Path

import pyarrow.dataset as ds
import pyarrow as pa
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.linear_model import LinearRegression

from lib.phenotype.constants import DEFAULT_METADATA_COLS
from lib.aggregate.align import prepare_alignment_data, centerscale_on_controls
from lib.aggregate.cell_data_utils import load_metadata_cols, split_cell_data

In [2]:
filtered_data_dir = Path(
    "/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets"
)
filtered_data_paths = list(
    filtered_data_dir.glob("*_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet")
)

filtered_data_paths[:5]

[PosixPath('/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets/P-1_W-A1_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet'),
 PosixPath('/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets/P-3_W-B3_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet'),
 PosixPath('/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets/P-1_W-B3_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet'),
 PosixPath('/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets/P-1_W-A3_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet'),
 PosixPath('/lab/ops_analysis/cheeseman/denali-analysis/analysis/brieflow_output/aggregate/parquets/P-2_W-A1_CeCl-all_ChCo-DAPI_CENPA__filtered.parquet')]

In [3]:
filtered_dataset = ds.dataset(filtered_data_paths[:1], format="parquet")
filtered_dataset = filtered_dataset.to_table(
    use_threads=True, memory_pool=None
).to_pandas()
filtered_dataset

Unnamed: 0,plate,well,tile,cell_0,i_0,j_0,site,cell_1,i_1,j_1,...,cytoplasm_zernike_9_1,cytoplasm_zernike_9_3,cytoplasm_zernike_9_5,cytoplasm_zernike_9_7,cytoplasm_zernike_9_9,cytoplasm_number_neighbors_1,cytoplasm_percent_touching_1,cytoplasm_first_neighbor_distance,cytoplasm_second_neighbor_distance,cytoplasm_angle_between_neighbors
0,1,A1,1160,603,1480.918072,1475.887550,271,2073,671.410256,100.538462,...,0.169824,0.153202,0.123852,0.109058,0.048195,0,0.000000,60.189160,68.498827,69.299229
1,1,A1,1185,548,1484.316222,1478.959617,300,228,101.153846,670.461538,...,0.211755,0.210846,0.075870,0.044210,0.046080,0,0.000000,67.295378,68.514281,178.202148
2,1,A1,1448,558,1480.565083,1486.317149,351,1846,671.774648,103.760563,...,0.126162,0.096632,0.010390,0.051202,0.015212,1,0.087719,39.535027,55.081105,79.910721
3,1,A1,411,436,1480.081002,1486.503497,106,231,100.990991,672.018018,...,0.177218,0.158275,0.055382,0.007802,0.013903,0,0.000000,57.437006,82.084600,107.190937
4,1,A1,66,465,1478.631083,1473.532967,14,230,99.607143,96.666667,...,0.112981,0.093791,0.041740,0.045495,0.026236,0,0.000000,61.219186,69.652386,54.631190
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
379518,1,A1,1343,20,98.849755,35.512793,305,2359,895.500000,312.500000,...,0.167412,0.137871,0.055244,0.047822,0.021110,0,0.000000,141.557476,188.759464,18.475536
379519,1,A1,1084,4,42.036866,88.276498,258,943,311.324561,893.535088,...,0.271584,0.084160,0.160228,0.014167,0.015366,0,0.000000,66.290674,89.802533,102.855450
379520,1,A1,1290,17,87.868983,42.111166,288,2305,893.556452,313.564516,...,0.091929,0.193869,0.083789,0.026767,0.035904,0,0.000000,65.774531,81.612313,97.183674
379521,1,A1,1166,4,32.591142,94.303594,269,830,308.265306,895.612245,...,0.167841,0.119827,0.022149,0.049574,0.028167,0,0.000000,58.319881,175.937500,78.187734


In [4]:
filtered_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 379523 entries, 0 to 379522
Columns: 898 entries, plate to cytoplasm_angle_between_neighbors
dtypes: Int64(50), bool(1), float32(1), float64(794), int64(48), object(1), string(3)
memory usage: 2.6+ GB


In [5]:
!free -h

              total        used        free      shared  buff/cache   available
Mem:           31Gi        26Gi       2.2Gi       1.0Mi       2.5Gi       4.3Gi
Swap:         974Mi       974Mi       0.0Ki


In [None]:
GENE = "RPS8"  # high perturbation score example
# GENE = "SKA3"  # low n, high perturbation score example
# GENE = "PTP4A2"  # low perturbation score example

subset_dfs = []
for filtered_data_path in filtered_data_paths:
    print(f"Loading {filtered_data_path}")
    filtered_dataset = ds.dataset(filtered_data_path, format="parquet")
    perturbation_col = filtered_dataset.to_table(columns=["gene_symbol_0"]).to_pandas()[
        "gene_symbol_0"
    ]

    gene_indices = perturbation_col.str.contains(GENE, na=False).to_numpy().nonzero()[0]
    nontargeting_indices = (
        perturbation_col.str.contains("nontargeting", na=False).to_numpy().nonzero()[0]
    )
    nontargeting_indices = np.random.choice(
        nontargeting_indices, size=len(gene_indices), replace=False
    )
    combined_indices = np.union1d(gene_indices, nontargeting_indices)

    subset_df = filtered_dataset.scanner(use_threads=True).take(
        pa.array(combined_indices)
    )
    subset_df = subset_df.to_pandas(use_threads=True, memory_pool=None).reset_index(
        drop=True
    )

    subset_dfs.append(subset_df)

subset_df = pd.concat(subset_dfs, ignore_index=True)
subset_df

In [None]:
metadata_cols = DEFAULT_METADATA_COLS + ["class", "confidence"]
feature_cols = subset_df.columns.difference(metadata_cols, sort=False)

metadata, features = split_cell_data(subset_df, metadata_cols)
metadata, features = prepare_alignment_data(
    metadata,
    features,
    ["plate", "well"],
    "gene_symbol_0",
    "nontargeting",
    "sgRNA_0",
)
features = features.astype(np.float32)

features = centerscale_on_controls(
    features,
    metadata,
    "gene_symbol_0",
    "nontargeting",
    "batch_values",
)
features = pd.DataFrame(features, columns=feature_cols)

subset_df_scaled = pd.concat([metadata, features], axis=1)
subset_df_scaled

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.metrics import roc_auc_score


def get_perturbation_score(
    cell_data, gene, feature_cols, n_differential_features=200, auroc_cutoff=0.6
):
    """Per-cell perturbation scores via 5-fold out-of-fold logistic regression with top-k feature selection.

    AUROC guide:
      - < 0.6  → basically noise; don’t filter (return NaN scores and keep all cells)
      - 0.6–0.75 → weak/moderate separation; filter cautiously
      - > 0.75 → decent separation; filtering makes sense
      - > 0.85–0.9 → strong separation; filtering always safe and effective
    """
    y = (cell_data["gene_symbol_0"] == gene).astype(int).to_numpy()
    X_all = cell_data[feature_cols].to_numpy()

    # select top-k differential features (ANOVA F-test)
    k = min(n_differential_features, X_all.shape[1])
    selector = SelectKBest(score_func=f_classif, k=k).fit(X_all, y)
    X = selector.transform(X_all)

    clf = LogisticRegression(max_iter=2000, class_weight="balanced", solver="liblinear")
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    scores = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")[:, 1]

    auc = roc_auc_score(y, scores)
    if auc < auroc_cutoff:
        # no meaningful separation → return NaNs so downstream keeps all cells
        return pd.Series(np.nan, index=cell_data.index), auc

    return pd.Series(scores, index=cell_data.index), auc


perturbation_scores, auc = get_perturbation_score(
    subset_df_scaled, GENE, feature_cols, 100
)
print(auc)
subset_df_scaled["perturbation_score"] = perturbation_scores
subset_df_scaled

In [None]:
plot_df = pd.concat(
    [
        subset_df_scaled[subset_df_scaled["gene_symbol_0"] != GENE].assign(
            group="control cells"
        ),
        subset_df_scaled[subset_df_scaled["gene_symbol_0"] == GENE].assign(
            group=f"{GENE} cells"
        ),
    ]
)

plt.figure(figsize=(6, 4))
sns.histplot(
    data=plot_df,
    x="perturbation_score",
    hue="group",
    bins=50,
    element="step",
    fill=True,
    stat="count",
    common_norm=False,
    alpha=0.5,
)
plt.xlabel("Perturbation score")
plt.ylabel("Cell count")
sns.despine()
plt.tight_layout()
plt.show()

In [None]:
import umap

# Prepare data
X = subset_df_scaled[feature_cols]
control_X = subset_df_scaled[
    subset_df_scaled["gene_symbol_0"].str.startswith("nontargeting")
][feature_cols]
is_gene = subset_df_scaled["gene_symbol_0"] == GENE

# UMAP
embedding = umap.UMAP(n_jobs=-1).fit_transform(X)

# Plot
plt.figure(figsize=(8, 5))
plt.scatter(
    embedding[~is_gene, 0],
    embedding[~is_gene, 1],
    color="gray",
    s=5,
    alpha=0.3,
    label="nontargeting",
)
plt.scatter(
    embedding[is_gene, 0],
    embedding[is_gene, 1],
    color="orange",
    s=5,
    alpha=0.9,
    label=GENE,
)
plt.legend()
plt.title(f"UMAP of Cells Colored by Gene ({GENE})")
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(8, 5))
plt.scatter(
    embedding[~is_gene, 0],
    embedding[~is_gene, 1],
    c="gray",
    s=5,
    alpha=0.3,
    label="nontargeting",
)
plt.scatter(
    embedding[is_gene, 0],
    embedding[is_gene, 1],
    c=subset_df_scaled[is_gene]["perturbation_score"],
    cmap="coolwarm",
    s=5,
    alpha=0.9,
    label=GENE,
)
plt.legend()
plt.title(f"UMAP Colored by Projection-Norm Perturbation Score ({GENE})")
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.colorbar(label="Perturbation Score (0-1)")
plt.tight_layout()
plt.show()