In [None]:
import pandas as pd
import numpy as np
import seaborn as sb
import os
import umap
import time
import gc

from tqdm import tqdm
import matplotlib.pyplot as plt

import profiling

In [None]:
# Root directory of the project; the metadata index and the per-site feature
# files are located relative to it.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
PROJECT_ROOT = "/raid/data/cellpainting/TAORF/"
# Experiment sub-folder under "<PROJECT_ROOT>outputs/" that holds the feature files.
EXP = "cp_dataset"
# Number of feature columns; used to build the list of feature column labels.
NUM_FEATURES = 672

# Prefix of every CSV file written by this notebook.
OUTPUT_BASE_NAME = 'efn128_cp_dataset'

# Name of the column that holds the perturbation (treatment) name.
PERT_NAME = "pert_name"
# Value of the PERT_NAME column that identifies control samples.
CTRL_NAME = "EMPTY_"

# File name handed to profiling.load_correlation_matrix.
CORR_MATRIX = "efn128new_cpdataset_train_1e-2_e30.csv"

# Regularization parameter handed to profiling.WhiteningNormalizer.
REG_PARAM = 1e-2

In [None]:
def find_first_hits(features, feats, pert_name, control):
    """Find, for every treated sample, the rank of its first same-treatment neighbour.

    Each non-control (plate, well, perturbation) group is used as a query: its
    rows are compared by cosine similarity against every row that belongs to a
    different well, the candidates are sorted from most to least similar, and
    the zero-based position of the first candidate carrying the same
    perturbation is recorded.

    Parameters
    ----------
    features : pandas.DataFrame
        Must contain the columns "Metadata_Plate", "Metadata_Well", "val",
        the column named by `pert_name`, and the feature columns in `feats`.
    feats : list
        Labels of the feature columns.
    pert_name : str
        Name of the column holding the perturbation name.
    control : object
        Value of `pert_name` that marks controls. Control wells are never used
        as queries; if they are present in `features` they still take part in
        the ranking as candidates.

    Returns
    -------
    pandas.DataFrame
        One row per query row with the columns "Metadata_Plate",
        "Metadata_Well", "pert_name" and "first_hit". "first_hit" is NaN when
        no other well carries the same perturbation.
    """
    results = []

    # Loop-invariant data: extract the feature matrix and its row norms once.
    X = np.asarray(features[feats])
    norms = np.linalg.norm(X, axis=1)
    plate_col = features["Metadata_Plate"]
    well_col = features["Metadata_Well"]
    pert_col = features[pert_name]

    gen = (
        features[pert_col != control]
        .groupby(["Metadata_Plate", "Metadata_Well", pert_name])["val"]
        .count()
        .reset_index()
        .iterrows()
    )
    for _, r in tqdm(gen):
        # Boolean masks instead of query strings: they work for numeric and
        # string plate identifiers alike and need no quoting of the values.
        in_well = ((plate_col == r["Metadata_Plate"]) & (well_col == r["Metadata_Well"])).to_numpy()
        others = ~in_well

        # Cosine similarity between the rows of this well and all other rows
        A = X[in_well]
        B = X[others]
        cos = np.dot(A, B.T) / np.outer(norms[in_well], norms[others])

        # Rank the candidates of every query row, most similar first
        ranking = np.argsort(-cos, axis=1)

        # Which candidates carry the same perturbation as the query well
        is_match = (pert_col == r[pert_name]).to_numpy()[others]
        if is_match.any():
            # Every row of `ranking` is a permutation of all candidates, so a
            # match exists for each query row; argmax gives the first one.
            first_hits = is_match[ranking].argmax(axis=1)
        else:
            # No other well has this perturbation: there is no rank to report.
            first_hits = np.full(A.shape[0], np.nan)

        for hit in first_hits:
            results.append({"Metadata_Plate": r["Metadata_Plate"],
                            "Metadata_Well": r["Metadata_Well"],
                            "pert_name": r[pert_name],
                            "first_hit": hit,
                           })
    return pd.DataFrame(
        data=results,
        columns=["Metadata_Plate", "Metadata_Well", "pert_name", "first_hit"],
    )

In [None]:
def summarize(results):
    """Collapse per-sample first-hit ranks into one row per perturbation.

    `results` needs the column named by PERT_NAME and a "first_hit" column.
    The returned frame holds, per perturbation, the mean rank ("first_hit"),
    its standard deviation ("std"), the mean rank as a percentage of the
    number of rows in `results` ("top_percent"), that percentage rounded up
    ("percent_group"), and the two ratios std/mean ("coef_var") and mean/std
    ("signal_noise").
    """
    stats = results.groupby([PERT_NAME])["first_hit"].agg(["mean", "std"])
    summary = stats.reset_index().rename(columns={"mean": "first_hit"})

    mean_rank = summary["first_hit"]
    spread = summary["std"]

    summary["top_percent"] = (mean_rank / len(results))*100
    summary["percent_group"] = np.ceil(summary["top_percent"])
    summary["coef_var"] = spread / mean_rank
    summary["signal_noise"] = mean_rank / spread
    return summary

In [None]:
def visualize(summary):
    """Plot "top_percent" per perturbation and return the summary sorted by rank.

    The rows are ordered by ascending "first_hit" (missing values last), drawn
    as a bar plot, and the number of rows whose "top_percent" is at most 1 is
    printed. The sorted frame is returned so the caller can store it.
    """
    ordered = summary.sort_values("first_hit", na_position='last')

    plt.figure(figsize=(10,5))
    sb.barplot(data=ordered, x=PERT_NAME, y="top_percent")

    within_one_percent = ordered["top_percent"] <= 1
    print("Treatments with hits in the top 1%:", int(within_one_percent.sum()))
    plt.show()
    return ordered

In [None]:
# Read the image index and the table that is joined to it on the sample identifier.
metadata = pd.read_csv(os.path.join(PROJECT_ROOT, "inputs/metadata/index_taorf_minus2wells.csv"))
Y = pd.read_csv("../data/TAORF_MOA_MATCHES.csv")

# Inner join: keeps only the index rows whose "broad_sample" appears as "Var1" in Y.
profiles = metadata.merge(Y, left_on="broad_sample", right_on="Var1")

# Append the control rows of the index and renumber the rows; the previous row
# labels are kept in an extra "index" column.
meta = pd.concat([profiles, metadata[metadata[PERT_NAME] == CTRL_NAME]], axis=0).reset_index()

In [None]:
# Load the feature array of every row in `meta`. A row whose file does not
# exist contributes an empty list, so positions in `features` keep matching
# the row order of `meta`.
features = []
site_keys = list(zip(meta["Metadata_Plate"], meta["Metadata_Well"], meta["Metadata_Site"]))
for plate_id, well_id, site_id in tqdm(site_keys):
    filename = (PROJECT_ROOT + "outputs/" + EXP + "/features/{}/{}/{}.npz").format(
        plate_id, well_id, site_id
    )
    if not os.path.isfile(filename):
        features.append([])
        continue
    with open(filename, "rb") as data:
        # Read the array while the file handle is still open.
        features.append(np.load(data)["features"])

In [None]:
# Count the rows of every loaded feature array. Entries for missing files are
# empty lists, so they are skipped both for the total and for the example shape
# (indexing features[0] directly fails when the first entry is a placeholder).
total_single_cells = sum(f.shape[0] for f in features if len(f) > 0)
first_shape = next((f.shape for f in features if len(f) > 0), None)

print("Total images",len(features),first_shape)
print("Total single cells:", total_single_cells)

In [None]:
# Feature columns are labelled with the integers 0 .. NUM_FEATURES - 1.
feats = list(range(NUM_FEATURES))

In [None]:
# Build one median profile per site, together with the metadata of its well.
# Wells are visited plate by plate, in order of first appearance in `meta`.
site_level_data = []
site_level_features = []
for plate in tqdm(meta["Metadata_Plate"].unique()):
    in_plate = meta["Metadata_Plate"] == plate
    for well in meta.loc[in_plate, "Metadata_Well"].unique():
        # Boolean mask instead of a formatted query string: it does not depend
        # on whether plate identifiers are numbers or strings.
        result = meta[in_plate & (meta["Metadata_Well"] == well)]
        if result.empty:
            continue

        # The metadata is the same for every site of the well: read it once.
        well_info = {
            "Metadata_Plate": plate,
            "Metadata_Well": well,
            PERT_NAME: result[PERT_NAME].iloc[0],
            "Replicate": result["pert_name_replicate"].iloc[0],
            "broad_sample": result["broad_sample"].iloc[0],
            "val": result["val"].iloc[0],
        }
        for i in result.index:
            # Sites without a feature file were stored as empty placeholders.
            if len(features[i]) == 0:
                continue
            site_level_data.append(dict(well_info))
            site_level_features.append(np.median(features[i], axis=0))

In [None]:
# The per-site feature arrays are not referenced again after the site-level
# medians have been computed: drop the name and run a garbage collection.
del features
gc.collect()

In [None]:
# Metadata columns that accompany every profile table in this notebook.
columns1 = ["Metadata_Plate", "Metadata_Well", PERT_NAME, "Replicate", "broad_sample", "val"]

# Both tables were filled row by row in the same loop, so they share the same
# default row index and can be placed side by side.
sites1 = pd.DataFrame(site_level_data, columns=columns1)
sites2 = pd.DataFrame(site_level_features, columns=feats)
sites = sites1.join(sites2)

In [None]:
# Build the normalizer from the feature columns of the control sites only.
whN = profiling.WhiteningNormalizer(sites.loc[sites[PERT_NAME] == CTRL_NAME, feats], reg_param=REG_PARAM)

In [None]:
# Keep the non-control sites as a fresh, renumbered frame. Chaining
# reset_index (instead of calling it in place on the filtered slice) gives an
# independent object, so the column assignment below does not write to a slice
# of `sites`.
sites_treatments = sites[sites[PERT_NAME] != CTRL_NAME].reset_index(drop=True)

# Replace the raw feature values by their normalized counterparts.
whD = whN.normalize(sites_treatments[feats])
sites_treatments[feats] = whD

In [None]:
# Site level: for every treated site, rank of the first site from another well
# that carries the same perturbation.
img_results = find_first_hits(sites_treatments, feats, PERT_NAME, CTRL_NAME)

In [None]:
# Aggregate the site-level ranks per perturbation, plot them, and save the
# summary in the sorted order returned by visualize().
img_summary = summarize(img_results)
img_summary = visualize(img_summary)
img_summary.to_csv(OUTPUT_BASE_NAME + '_image_level.csv')

In [None]:
# Average the site profiles of every well. `sites` also holds text columns
# (e.g. "broad_sample"); numeric_only=True leaves them out of the mean, which
# otherwise raises a TypeError on recent pandas versions.
well_keys = ["Metadata_Plate", "Metadata_Well", PERT_NAME]
wells = sites.groupby(well_keys).mean(numeric_only=True).reset_index()

# Bring "broad_sample" back from the metadata (one row per well and sample).
tmp = meta.groupby(well_keys + ["broad_sample"])["DNA"].count().reset_index()
wells = pd.merge(wells, tmp, how="left", on=well_keys)

# Keep the metadata columns followed by the feature columns.
wells = wells[columns1 + feats]

In [None]:
# Build a new normalizer from the feature columns of the control wells. This
# rebinds `whN`, replacing the normalizer that was built from the site profiles.
whN = profiling.WhiteningNormalizer(wells.loc[wells[PERT_NAME] == CTRL_NAME, feats], REG_PARAM)

In [None]:
# Keep the non-control wells as a fresh, renumbered frame. Chaining
# reset_index (instead of calling it in place on the filtered slice) gives an
# independent object, so the column assignment below does not write to a slice
# of `wells`.
wells_treatments = wells[wells[PERT_NAME] != CTRL_NAME].reset_index(drop=True)

# Replace the raw feature values by their normalized counterparts.
whD = whN.normalize(wells_treatments[feats])
wells_treatments[feats] = whD

In [None]:
# Well level: for every treated well, rank of the first other well that
# carries the same perturbation.
well_results = find_first_hits(wells_treatments, feats, PERT_NAME, CTRL_NAME)

In [None]:
# Aggregate the well-level ranks per perturbation, plot them, and save the
# summary in the sorted order returned by visualize().
well_summary = summarize(well_results)
well_summary = visualize(well_summary)
well_summary.to_csv(OUTPUT_BASE_NAME + '_well_level.csv')

In [None]:
# Average the normalized well profiles of every treatment and order the rows
# by "broad_sample". The frame still contains the text column "Metadata_Well";
# numeric_only=True leaves it out of the mean, which otherwise raises a
# TypeError on recent pandas versions.
treatment_features = (
    wells_treatments
    .groupby([PERT_NAME, "broad_sample"])
    .mean(numeric_only=True)
    .reset_index()
    .sort_values("broad_sample")
    .reset_index(drop=True)
)

In [None]:
# Ground truth connections
X, Y = profiling.load_correlation_matrix(CORR_MATRIX)

# Attach the treatment identifiers row by row; the printed count shows how many
# rows of Y line up with the identifier stored in its own "Var1" column.
Y["broad_sample"] = treatment_features.broad_sample
print("Treatments with ground truth:", np.sum(Y.broad_sample == Y.Var1))

# Square 0/1 matrix: entry (i, j) is 1 when rows i and j of Y share the same
# "Metadata_moa.x" value.
moa_labels = Y["Metadata_moa.x"]
moa_matches = np.asarray([moa_labels == label for label in moa_labels], dtype=np.uint8)
plt.imshow(moa_matches)

In [None]:
# Similarity search: cosine similarity between all pairs of treatment profiles.
F = treatment_features.loc[treatment_features[PERT_NAME] != CTRL_NAME, feats]

# Dot products of every pair of rows, divided by the product of their norms.
C = np.dot(F, F.T)
Fn = np.linalg.norm(F, axis=1)
cos = C / np.outer(Fn, Fn)

# For every treatment, the positions of all treatments from most to least similar.
ranking = np.argsort(-cos, axis=1)

In [None]:
# First hits evaluation
results = []
for h in range(cos.shape[0]):
    # Positions, within this treatment's ranking, of all treatments that share
    # its MOA. The treatment matches itself, so the second position is taken as
    # the first hit among the other treatments.
    hit = np.where(moa_matches[h, ranking[h]] == 1)[0]
    if len(hit) >= 2:
        hit = hit[1]
    else:
        # No other treatment shares this MOA. Record a missing value instead of
        # rank 0, which would otherwise count as a perfect hit in the summaries.
        hit = np.nan
        print(h, Y.loc[h, "Metadata_moa.x"])
    results.append({"broad_sample": treatment_features.loc[h, "broad_sample"],
                    PERT_NAME: treatment_features.loc[h, PERT_NAME],
                    "first_hit": hit,
                   })
results = pd.DataFrame(data=results)

In [None]:
# Summary statistics per treatment
# Express every rank as a percentage of the number of treatments (this adds the
# "top_percent" column to `results` itself), then plot and save the sorted table.
results["top_percent"] = (results["first_hit"] / len(results))*100
treatment_summary = visualize(results)
treatment_summary.to_csv(OUTPUT_BASE_NAME + '_treatment_level.csv')

In [None]:
# Summary per MOA: mean first-hit rank of the treatments of each MOA, expressed
# as a percentage of the number of treatments and sorted from best to worst.
moa_results = (
    pd.merge(results, Y, on="broad_sample")
    .groupby("Metadata_moa.x")["first_hit"]
    .mean()
    .reset_index()
    .assign(top_percent=lambda frame: (frame["first_hit"] / len(results))*100)
    .sort_values("top_percent")
)
sb.barplot(data=moa_results, x="Metadata_moa.x", y="top_percent")
moa_results.to_csv(OUTPUT_BASE_NAME + '_moa.csv')