In [None]:
import pandas as pd
import numpy as np
import seaborn as sb
import os
import sys
import time
import gc

from tqdm import tqdm
import matplotlib.pyplot as plt

sys.path.append("../profiling/")
import profiling

In [None]:
# Experiment configuration.
# PROJECT_ROOT may be overridden with the TAORF_PROJECT_ROOT environment
# variable instead of editing the notebook. os.path.join(..., "") guarantees a
# trailing separator, which string concatenation of sub-paths elsewhere in the
# notebook relies on.
PROJECT_ROOT = os.path.join(os.environ.get("TAORF_PROJECT_ROOT", "/raid/data/cellpainting/TAORF/"), "")
EXP = "cp_dataset"
# Expected width of each single-cell feature vector; the feature columns are
# the integer labels 0..NUM_FEATURES-1.
NUM_FEATURES = 672

# Prefix of the per-repetition summary CSV files written at the end.
OUTPUT_BASE_NAME = 'efn128_cp_dataset'

# Metadata column holding the perturbation, and its value for control wells.
PERT_NAME = "pert_name"
CTRL_NAME = "EMPTY_"

# Regularization passed to profiling.WhiteningNormalizer.
REG_PARAM = 1e-2

In [None]:
def find_first_hits(features, feats, pert_name, control):
    """Rank-based retrieval test: how soon does a sample find its own perturbation?

    For every non-control (plate, well, perturbation) group, each sample in that
    well is compared by cosine similarity against every sample in *other* wells
    (controls included as candidates). Candidates are ranked by decreasing
    similarity. The 0-based rank of the first candidate carrying the same
    perturbation is recorded as ``first_hit`` (lower is better).

    Parameters
    ----------
    features : pd.DataFrame
        One row per sample with ``Metadata_Plate``, ``Metadata_Well`` and
        ``pert_name`` columns, plus the feature columns listed in ``feats``.
    feats : list
        Column labels holding the feature vector.
    pert_name : str
        Name of the perturbation column.
    control : str
        Perturbation value of control samples; control wells are not queried.

    Returns
    -------
    pd.DataFrame
        One row per queried sample with columns ``Metadata_Plate``,
        ``Metadata_Well``, ``<pert_name>`` and ``first_hit``. ``first_hit`` is
        NaN when no sample in another well shares the perturbation.
    """
    keys = ["Metadata_Plate", "Metadata_Well", pert_name]
    # One row per treated (plate, well, perturbation) combination to query.
    queries = features[features[pert_name] != control].groupby(keys).size().reset_index()

    results = []
    for _, r in tqdm(queries.iterrows(), total=len(queries)):
        # Boolean masks instead of DataFrame.query: they work for string plate
        # IDs and never duplicate rows when index labels repeat.
        in_well = (
            (features["Metadata_Plate"] == r["Metadata_Plate"])
            & (features["Metadata_Well"] == r["Metadata_Well"])
        ).to_numpy()
        A = features.loc[in_well, feats].to_numpy()
        B = features.loc[~in_well, feats].to_numpy()

        # Cosine similarity of every sample in the well against every other sample.
        cos = (A @ B.T) / np.outer(np.linalg.norm(A, axis=1), np.linalg.norm(B, axis=1))

        # Candidate positions ordered from most to least similar, per query row.
        ranking = np.argsort(-cos, axis=1)

        same_pert = (features.loc[~in_well, pert_name] == r[pert_name]).to_numpy()
        for hits_in_rank_order in same_pert[ranking]:
            hit_positions = np.flatnonzero(hits_in_rank_order)
            # No other well with this perturbation -> no retrievable hit.
            first_hit = int(hit_positions[0]) if hit_positions.size else np.nan
            results.append({"Metadata_Plate": r["Metadata_Plate"],
                            "Metadata_Well": r["Metadata_Well"],
                            pert_name: r[pert_name],
                            "first_hit": first_hit,
                           })
    return pd.DataFrame(data=results)

In [None]:
def summarize(results):
    """Aggregate per-sample first-hit ranks into one row per perturbation.

    Columns of the returned frame, in order:
    ``PERT_NAME``, ``first_hit`` (mean rank), ``std`` (sample std of the rank),
    ``top_percent`` (mean rank as a percentage of ``len(results)``),
    ``percent_group`` (``top_percent`` rounded up), ``coef_var`` (std / mean)
    and ``signal_noise`` (mean / std).
    """
    n_rows = len(results)
    per_pert = (
        results.groupby([PERT_NAME])["first_hit"]
        .agg(first_hit="mean", std="std")
        .reset_index()
    )
    return per_pert.assign(
        top_percent=lambda d: (d["first_hit"] / n_rows) * 100,
        percent_group=lambda d: np.ceil(d["top_percent"]),
        coef_var=lambda d: d["std"] / d["first_hit"],
        signal_noise=lambda d: d["first_hit"] / d["std"],
    )

In [None]:
def visualize(summary):
    """Bar-plot each perturbation's ``top_percent`` and count top-1% treatments.

    Parameters
    ----------
    summary : pd.DataFrame
        Per-perturbation table with at least the ``PERT_NAME``, ``first_hit``
        and ``top_percent`` columns (as produced by ``summarize``).

    Returns
    -------
    pd.DataFrame
        ``summary`` sorted by ascending ``first_hit`` (NaNs last).
    """
    summary = summary.sort_values("first_hit", na_position='last')

    fig, ax = plt.subplots(figsize=(10, 5))
    # Pass the sorted names explicitly so bar order always follows the sort,
    # regardless of how seaborn would infer category order from the data.
    sb.barplot(data=summary, x=PERT_NAME, y="top_percent",
               order=summary[PERT_NAME].tolist(), ax=ax)
    ax.set(xlabel=PERT_NAME, ylabel="top_percent (first-hit rank, %)",
           title="First same-perturbation hit per treatment")
    # There is one bar per perturbation, so horizontal labels would overlap.
    ax.tick_params(axis="x", labelrotation=90)

    n_top1 = int((summary["top_percent"] <= 1).sum())
    print("Treatments with hits in the top 1%:", n_top1)
    plt.show()
    return summary

In [None]:
# Load the plate-map metadata and the MOA match table. Keep the metadata rows
# whose broad_sample appears in Y["Var1"], then append every control row.
# NOTE(review): if Var1 repeats in the match table, the inner merge duplicates
# the corresponding image rows — confirm Var1 is unique.
metadata = pd.read_csv(os.path.join(PROJECT_ROOT, "inputs/metadata/index_taorf_minus2wells.csv"))
Y = pd.read_csv("../data/TAORF_MOA_MATCHES.csv")
profiles = metadata.merge(Y, left_on="broad_sample", right_on="Var1")
controls_meta = metadata.loc[metadata[PERT_NAME] == CTRL_NAME]
meta = pd.concat([profiles, controls_meta], axis=0).reset_index()

In [None]:
# Load one single-cell feature matrix per image (one .npz per plate/well/site),
# keeping `features` aligned position-by-position with the rows of `meta`.
# A missing file gets an empty (0, NUM_FEATURES) placeholder rather than a
# list, so later cells can call .shape / len() on every entry.
features = []
missing_files = []
for i in tqdm(meta.index):
    filename = os.path.join(
        PROJECT_ROOT, "outputs", EXP, "features",
        "{}".format(meta.loc[i, "Metadata_Plate"]),
        "{}".format(meta.loc[i, "Metadata_Well"]),
        "{}.npz".format(meta.loc[i, "Metadata_Site"]),
    )
    if os.path.isfile(filename):
        # NpzFile is a context manager; leaving the block closes the archive.
        with np.load(filename) as info:
            features.append(info["features"])
    else:
        missing_files.append(filename)
        features.append(np.zeros((0, NUM_FEATURES)))

# Report skipped images instead of dropping them silently.
if missing_files:
    print("Images without a feature file:", len(missing_files))

In [None]:
# Count the single cells loaded across all images. Missing images are empty
# entries, so use len() (not .shape, which a [] placeholder lacks). Report the
# shape of the first image that actually has cells, instead of features[0],
# which may be one of those placeholders.
total_single_cells = sum(len(f) for f in features)
first_shape = next((f.shape for f in features if len(f) > 0), None)

print("Total images",len(features),first_shape)
print("Total single cells:", total_single_cells)

In [None]:
# Stack every image's cells into one (n_cells, n_dims) matrix, and build a
# matching single-cell metadata table. Row j of sc_meta describes row j of
# sc_features; its "ID" column is the position of the source image in
# `features`, which is also its row label in `meta`.
cols = ["Metadata_Plate","Metadata_Well","Metadata_Site","pert_name","broad_sample","pert_name_replicate","val"]
# Missing images are empty entries (a [] list or a zero-row array). len()
# handles both, whereas .shape fails on a list.
cells_per_image = np.array([len(f) for f in features], dtype=np.int64)
n_dims = next((f.shape[1] for f in features if len(f) > 0), NUM_FEATURES)
sc_features = np.zeros((int(cells_per_image.sum()), n_dims))
k = 0
for i in tqdm(range(len(features))):
    cells = int(cells_per_image[i])
    if cells == 0:
        continue
    sc_features[k:k+cells,:] = features[i]
    k += cells
# Image index repeated once per cell, in the same order the rows were stacked.
sc_meta_idx = np.repeat(np.arange(len(features)), cells_per_image)

sc_meta = pd.merge(pd.DataFrame(sc_meta_idx, columns=["ID"]), meta[cols], left_on="ID", right_index=True)

In [None]:
# The per-image list is no longer needed: all of its rows were copied into
# sc_features. Drop the name and force a collection so the per-image arrays
# can be reclaimed before the memory-heavy cells below.
del features
gc.collect()

In [None]:
# Partition the single-cell metadata: control cells (PERT_NAME == CTRL_NAME)
# versus every other cell.
is_control = sc_meta[PERT_NAME] == CTRL_NAME
sc_controls = sc_meta[is_control]
sc_treatments = sc_meta[~is_control]

In [None]:
# Attach feature vectors to the control cells by an index-on-index left join.
# This presumably relies on sc_meta's row labels equaling row positions in
# sc_features — confirm if the construction of sc_meta changes.
sc_control_features = sc_controls.merge(
    pd.DataFrame(data=sc_features),
    how="left",
    left_index=True,
    right_index=True,
)

In [None]:
# Feature column labels: the default integer columns 0..NUM_FEATURES-1 that
# pd.DataFrame(sc_features) produces.
feats = list(range(NUM_FEATURES))

In [None]:
# Repeat the single-cell retrieval benchmark on N_REPEATS random draws of up to
# CELLS_PER_WELL cells per treatment well. Each draw is whitened with
# control-cell statistics, scored with find_first_hits, summarized, plotted and
# written to its own CSV.
N_REPEATS = 10
CELLS_PER_WELL = 10

# Loop-invariant work, hoisted out of the repetition loop. The normalizer is
# fit on the control cells only, which do not change between repetitions.
# NOTE(review): assumes WhiteningNormalizer.normalize() does not mutate the
# fitted normalizer — confirm in ../profiling/profiling.py.
whN = profiling.WhiteningNormalizer(sc_control_features[feats], reg_param=REG_PARAM)
sc_features_frame = pd.DataFrame(data=sc_features)
treatment_wells = sc_treatments.groupby(["Metadata_Plate", "Metadata_Well", PERT_NAME]).size().reset_index()

for i in range(N_REPEATS):
    # Seed each repetition with its index so every output CSV is reproducible.
    rng = np.random.RandomState(i)
    sc_sample = []
    for _, r in tqdm(treatment_wells.iterrows(), total=len(treatment_wells)):
        # Boolean mask rather than an f-string query: works for string plate IDs.
        in_well = (
            (sc_treatments["Metadata_Plate"] == r["Metadata_Plate"])
            & (sc_treatments["Metadata_Well"] == r["Metadata_Well"])
        )
        well_cells = sc_treatments[in_well]
        # .sample(n) raises ValueError when a well holds fewer than n cells.
        sc_sample.append(well_cells.sample(min(CELLS_PER_WELL, len(well_cells)), random_state=rng))
    sc_sample = pd.concat(sc_sample)
    sc_sample_features = pd.merge(sc_sample, sc_features_frame, how="left", left_index=True, right_index=True)
    whD = whN.normalize(sc_sample_features[feats])
    sc_sample_features[feats] = whD
    sc_results = find_first_hits(sc_sample_features, feats, PERT_NAME, CTRL_NAME)
    sc_summary = summarize(sc_results)
    sc_summary = visualize(sc_summary)
    sc_summary.to_csv(OUTPUT_BASE_NAME + '_single_cell_level_sample_{}.csv'.format(i))