In [1]:
import os
import io
import random
import pandas as pd
import numpy as np
import pickle as pkl
from numpy import random
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
import colorcet as cc
import seaborn as sns
import numba
import time
sns.set_style(style='white') 
sns.set(rc={'figure.figsize':(12,8)})
palette = sns.color_palette("bright", 15)
palette2 = sns.color_palette("dark", 10)
palette3 = sns.color_palette(cc.glasbey, n_colors=20)

from umap.umap_ import UMAP
from collections import Counter, defaultdict


from process import read_file_embeddings, plate_wise_spherize_and_normailize, average_wells

0.10.3


In [2]:
def read_file_embeddings(cpfname, fname, f_dim=512, feature_cols="micon_", is_moa=False):
    """Load a metadata/profile table and inner-join pickled per-image embeddings onto it.

    NOTE: this definition shadows the ``read_file_embeddings`` imported from
    ``process`` in the first cell; whichever was defined last is used.

    Parameters
    ----------
    cpfname : str
        Path to the table; the extension must be ``.parquet`` or ``.csv``.
        The table must contain ``Metadata_Plate``, ``Metadata_Well`` and
        ``Metadata_Fov`` columns.
    fname : str
        Path to a pickle holding a pair ``(emb, names)``. ``emb`` is turned into
        a DataFrame with ``f_dim`` columns (one row per image); ``names`` is an
        iterable of batches, each an iterable of ``"$"``-separated identifiers
        whose fields 2, 3 and 4 are plate, well and field of view.
    f_dim : int
        Number of embedding dimensions (one output column per dimension).
    feature_cols : str
        Prefix for the embedding column names (``f"{feature_cols}{i}"``).
    is_moa : bool
        If True, add ``Metadata_Moa`` by applying ``get_moa`` to
        ``Metadata_InChIKey``. ``get_moa`` is not defined in this block and must
        exist elsewhere in the notebook.

    Returns
    -------
    pandas.DataFrame
        Rows of the table that have a matching embedding on
        (plate, well, FOV), with the embedding columns appended.

    Raises
    ------
    ValueError
        If ``cpfname`` does not end in ``.parquet`` or ``.csv``.
    """
    extension = os.path.splitext(cpfname)[1].lower()
    if extension == ".parquet":
        df_check = pd.read_parquet(cpfname)
    elif extension == ".csv":
        df_check = pd.read_csv(cpfname, low_memory=False)
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unsupported table format for {cpfname!r}; expected .parquet or .csv")
    # FOV is parsed as int from the embedding identifiers, so align the dtype for the merge.
    df_check["Metadata_Fov"] = df_check["Metadata_Fov"].astype(int)
    if is_moa:
        df_check["Metadata_Moa"] = df_check["Metadata_InChIKey"].apply(lambda x: get_moa(x)).tolist()

    # Only load trusted pickles: pickle.load can execute arbitrary code.
    with open(fname, "rb") as f:
        emb, image_names = pkl.load(f)

    # Flatten the batches of "$"-separated identifiers into per-image field lists.
    name_fields = [name.split("$") for batch in image_names for name in batch]

    df_emb = pd.DataFrame({
        "Metadata_Plate": [fields[2] for fields in name_fields],
        "Metadata_Well": [fields[3] for fields in name_fields],
        "Metadata_Fov": [int(fields[4]) for fields in name_fields],
    })
    df_feat = pd.DataFrame(data=emb, columns=[f"{feature_cols}{i}" for i in range(f_dim)])
    df_emb = pd.concat([df_emb, df_feat], axis=1)
    return df_check.merge(df_emb, on=["Metadata_Plate", "Metadata_Well", "Metadata_Fov"])

In [3]:
# Map each labelled compound's SMILES string to an integer class label.
# Labels follow list order, starting at 1; the control compound
# 'CS(C)=O' (DMSO) is appended last with the string label 'control'.
SMI2LABEL = {
    smiles: label
    for label, smiles in enumerate(
        [
            'c1ccc(-c2nn3c(c2-c2ccnc4cc(OCCN5CCOCC5)ccc24)CCC3)nc1',
            'COc1ncc2cc(C(=O)Nc3cc(C(O)=NCc4cccc(Cl)c4)ccc3Cl)c(O)nc2n1',
            'CC1CC2C3CC=C4CC(=O)C=CC4(C)C3(F)C(O)CC2(C)C1(O)C(=O)CO',
            'C=CC1CN2CCC1CC2C(O)c1ccnc2ccc(OC)cc12',
            'CCOC(=O)C1OC1C(O)=NC(CC(C)C)C(O)=NCCC(C)C',
            'Cc1csc(-c2nnc(Nc3ccc(Oc4ncccc4-c4cc[nH]c(=N)n4)cc3)c3ccccc23)c1',
            'O=C(c1ccccc1)N1CCC(CCCCN=C(O)C=Cc2cccnc2)CC1',
            'CC(C)N=C(O)N1CCC(N=C2Nc3cc(F)ccc3N(CC(F)F)c3ccc(Cl)cc32)C1',
        ],
        start=1,
    )
}
SMI2LABEL['CS(C)=O'] = 'control'

# Data sources used in this analysis (identifiers as they appear in Metadata_Source).
SOURCE_LIST = [
    'source_2',
    'source_3',
    'source_5',
    'source_6',
    'source_7',
    'source_8',
    'source_11',
]

In [4]:
def generate_source_split(X, Y, Y_s, Y_b, source):
    """Leave-one-source-out split.

    Rows whose source label in ``Y_s`` equals ``source`` form the test split;
    all other rows form the training split. ``X``, ``Y``, ``Y_s`` and ``Y_b``
    are all indexed with the same boolean mask, so they must be aligned and of
    equal length (e.g. numpy arrays).

    Returns
    -------
    tuple
        ``(X_train, X_test, (Y_train, Y_s_train, Y_b_train), (Y_test, Y_s_test, Y_b_test))``
    """
    keep_for_train = Y_s != source
    keep_for_test = Y_s == source
    label_arrays = (Y, Y_s, Y_b)
    train_labels = tuple(arr[keep_for_train] for arr in label_arrays)
    test_labels = tuple(arr[keep_for_test] for arr in label_arrays)
    return X[keep_for_train], X[keep_for_test], train_labels, test_labels

def knn_classifier(df_train, df_test=None, n_neighbors=3, feature_col=["Emb_"], feature_col_test="Emb_",
                   label_col="Metadata_Moa", label_col_test="Metadata_Moa", test_size=0.1, random_state=None):
    """Fit a cosine-distance k-NN classifier and report per-class and overall accuracy.

    Rows with a missing label are dropped (on copies — the caller's DataFrames
    are no longer modified in place). Labels are compared as strings.

    Parameters
    ----------
    df_train : pandas.DataFrame
        Training data. Used for both train and test (via ``train_test_split``)
        when ``df_test`` is None.
    df_test : pandas.DataFrame or None
        Optional held-out data.
    n_neighbors : int
        Number of neighbours for ``KNeighborsClassifier``.
    feature_col, feature_col_test : column selector(s)
        Columns of ``df_train`` / ``df_test`` used as features.
    label_col, label_col_test : str
        Label columns of ``df_train`` / ``df_test``.
    test_size : float
        Test fraction for ``train_test_split`` (only used when ``df_test`` is None).
    random_state : int or None
        Seed forwarded to ``train_test_split`` for a reproducible split.
        Defaults to None (unseeded, the previous behaviour).

    Returns
    -------
    pandas.DataFrame
        One row per test class plus a ``"Total"`` row, with columns
        ``Class``, ``# of Samples`` and ``Acc_Neighbor={n_neighbors}``. The
        per-class value is the fraction of that class's test samples that were
        predicted correctly.
    """
    # Non-mutating dropna: the old inplace=True silently edited the caller's frames.
    df_train = df_train.dropna(subset=label_col)
    X = df_train[feature_col]
    y = df_train[label_col].to_numpy().astype('str')
    if df_test is not None:
        df_test = df_test.dropna(subset=label_col_test)
        X_train, y_train = X, y
        X_test = df_test[feature_col_test]
        y_test = df_test[label_col_test].to_numpy().astype('str')
    else:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    print(f"Training samples: {len(X_train)}. Testing samples: {len(X_test)}")

    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="cosine")
    neigh.fit(X_train, y_train)
    y_pred = neigh.predict(X_test)

    classes = sorted(np.unique(y_test))
    per_class_acc = [accuracy_score(y_test[y_test == c], y_pred[y_test == c]) for c in classes]
    counts = [int((y_test == c).sum()) for c in classes]
    return pd.DataFrame({
        "Class": classes + ["Total"],
        "# of Samples": counts + [len(y_test)],
        f"Acc_Neighbor={n_neighbors}": per_class_acc + [accuracy_score(y_test, y_pred)],
    })

In [6]:
def drop_bad_columns(df, cols=None, min_std=0.1, max_std=5,
                     drop_patterns=("Nuclei_Correlation_RWC", "Nuclei_Correlation_Manders",
                                    "Nuclei_Granularity_14", "Nuclei_Granularity_15",
                                    "Nuclei_Granularity_16")):
    """Remove feature columns with an out-of-range standard deviation or a blacklisted name.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table (not modified).
    cols : sequence of column labels or None
        Candidate feature columns. If None or empty, every column whose name
        does not contain ``"Metadata_"`` is used. A pandas Index is accepted.
    min_std, max_std : float
        Columns with ``std < min_std`` or ``std > max_std`` are dropped
        (defaults 0.1 and 5, the previously hard-coded values). A NaN std
        matches neither test, so such columns are kept.
    drop_patterns : iterable of str
        Any candidate column whose name contains one of these substrings is
        dropped.

    Returns
    -------
    (pandas.DataFrame, list)
        The filtered frame and the list of dropped column names. The list is
        ordered std-based drops first, then pattern drops in pattern order, and
        may contain duplicates.
    """
    # `not cols` would raise on a pandas Index, so test explicitly.
    if cols is None or len(cols) == 0:
        cols = [c for c in df.columns if "Metadata_" not in c]
    stdev = [df[c].std() for c in cols]

    cols_to_drop = [c for c, s in zip(cols, stdev) if s < min_std or s > max_std]
    for pattern in drop_patterns:
        cols_to_drop.extend(c for c in cols if pattern in c)

    # Set lookup keeps the filter linear in the number of columns.
    drop_set = set(cols_to_drop)
    df = df[[c for c in df.columns if c not in drop_set]]
    return df, cols_to_drop

In [7]:
def generate_visualization_two(df_check_1, df_check_2, feature_cols_1, feature_cols_2, hues=None, type_viz='tsne',
                               perplexity=30, model_1_name="Default", model_2_name="Default_2", n_color=12, plot='all'):
    """Jointly embed two feature tables in 2-D (t-SNE or UMAP) and scatter-plot them per hue column.

    Both tables are concatenated, embedded together, then split back so the two
    point clouds share one coordinate system. One figure is drawn per entry of
    ``hues``.

    Parameters
    ----------
    df_check_1, df_check_2 : pandas.DataFrame
        The two tables. Every name in ``hues`` must be a column of both.
    feature_cols_1, feature_cols_2 : str or list
        If ``feature_cols_1`` is a str, both are treated as column-name
        prefixes; otherwise both are used as explicit column lists.
    hues : list of str
        Columns to colour by (required).
    type_viz : {'tsne', 'umap'}
        Embedding method.
    perplexity : int
        t-SNE perplexity, or UMAP ``n_neighbors``.
    model_1_name, model_2_name : str
        Used only in figure titles.
    n_color : int
        Unused; kept for call compatibility.
    plot : {'all', 'one', 'two'}
        Plot both tables, only the first, or only the second.

    Raises
    ------
    ValueError
        If ``hues`` is None, or ``type_viz`` / ``plot`` is unsupported.
    """
    if hues is None:
        raise ValueError("hues must list at least one column to colour points by.")
    # Validate cheap arguments before the expensive embedding step.
    if type_viz not in ('tsne', 'umap'):
        raise ValueError(f"{type_viz} not supported, choose from tsne/umap.")
    if plot not in ('all', 'one', 'two'):
        raise ValueError(f"plot={plot!r} not supported, choose from all/one/two.")

    if isinstance(feature_cols_1, str):
        cols_1 = [c for c in df_check_1.columns if c.startswith(feature_cols_1)]
        cols_2 = [c for c in df_check_2.columns if c.startswith(feature_cols_2)]
    else:
        cols_1, cols_2 = feature_cols_1, feature_cols_2
    X_1 = df_check_1[cols_1].values
    X_2 = df_check_2[cols_2].values
    n_rows_1 = len(X_1)
    X = np.concatenate([X_1, X_2])

    def _sorted_levels(series):
        # hue_order excludes NaN so seaborn never gets a NaN category.
        return sorted(set(series[series.notna()].to_numpy().flatten()))

    # One code path for any number of hues (the old multi-hue branch used
    # undefined names and called .flatten() on a Series).
    Y_1_total = [df_check_1[h].to_numpy() for h in hues]
    Y_2_total = [df_check_2[h].to_numpy() for h in hues]
    hue_1_order = [_sorted_levels(df_check_1[h]) for h in hues]
    hue_2_order = [_sorted_levels(df_check_2[h]) for h in hues]

    if type_viz == 'tsne':
        X_embedded = TSNE(perplexity=perplexity).fit_transform(X)
    else:
        X_embedded = UMAP(n_neighbors=perplexity).fit_transform(X)

    X_1_embedded = X_embedded[:n_rows_1]
    X_2_embedded = X_embedded[n_rows_1:]

    custom_palette1 = sns.color_palette("bright", 15)
    custom_palette2 = sns.color_palette("dark", 15)
    sns.set_style(style='white')
    for hue_name, label_1, label_2, order_1, order_2 in zip(hues, Y_1_total, Y_2_total, hue_1_order, hue_2_order):
        plt.figure()
        if plot == 'all':
            sns.scatterplot(x=X_1_embedded[:, 0], y=X_1_embedded[:, 1], hue=label_1, hue_order=order_1,
                            legend='full', palette=custom_palette1, markers=['o'])
            sns.scatterplot(x=X_2_embedded[:, 0], y=X_2_embedded[:, 1], hue=label_2, hue_order=order_2,
                            style=label_2, legend='full', palette=custom_palette2, markers=['^'] * len(order_2))
            title = f"{model_1_name}_vs_{model_2_name}_by_{hue_name}"
        elif plot == 'one':
            sns.scatterplot(x=X_1_embedded[:, 0], y=X_1_embedded[:, 1], hue=label_1, hue_order=order_1,
                            legend='full', palette=custom_palette1, markers=['o'])
            title = f"{model_1_name}_by_{hue_name}"
        else:  # plot == 'two' (validated above)
            sns.scatterplot(x=X_2_embedded[:, 0], y=X_2_embedded[:, 1], hue=label_2, hue_order=order_2,
                            legend='full', palette=custom_palette2, markers=['o'])
            title = f"{model_2_name}_by_{hue_name}"
        plt.title(title, loc='center')
        plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=4)

In [69]:
def NS_metric(df, feature_col, on="Metadata_Plate", topk=10, is_generated=False, all_negative=True):
    """Top-k "not-same-group" nearest-neighbour compound-matching accuracy within one table.

    Every row of ``df`` is a query. All rows are ranked by cosine distance in
    ``feature_col`` space. Neighbours are then filtered:

    * ``all_negative=False``: keep only rows whose ``on`` value differs from the query's.
    * ``all_negative=True``: additionally keep rows from the same group whose
      ``Metadata_SMILES`` differs, so only same-group, same-compound rows
      (including the query itself) are removed.

    A query is a hit at k if any of its first k remaining neighbours has the
    same ``Metadata_SMILES``. Queries whose SMILES is ``'CS(C)=O'`` (the
    control) are excluded from the score.

    With ``is_generated=True`` only rows in the first half of ``df`` are
    queries, and only rows in the second half are eligible neighbours
    (presumably real vs generated profiles — confirm against the caller).

    Parameters
    ----------
    on : str
        Name of the grouping column, e.g. ``"Metadata_Plate"`` or
        ``"Metadata_Source"`` (any column of ``df`` is accepted).

    Returns
    -------
    (list of float, list of defaultdict)
        ``acc_stats[k-1]`` is the fraction of treated queries with a hit within
        the top k (NaN if there are no treated queries); ``correct_smiles[k-1]``
        maps SMILES -> number of treated queries of that SMILES with a hit
        within the top k.
    """
    CONTROL = 'CS(C)=O'
    n_rows = len(df)
    half = n_rows / 2

    np_features = df[feature_col].to_numpy()
    knn = NearestNeighbors(n_neighbors=n_rows, metric="cosine")
    knn.fit(np_features)
    neighbours_mat = knn.kneighbors(np_features, return_distance=False)

    smiles = df["Metadata_SMILES"].to_numpy()
    groups = df[on].to_numpy()

    res = []
    for i, (s, g, rank) in tqdm(enumerate(zip(smiles, groups, neighbours_mat))):
        if is_generated and i >= half:
            break
        # Boolean mask over rows allowed as neighbours. Indexing it is O(1);
        # the old `r in np.argwhere(...)` scan made the whole metric O(n^3).
        allowed = groups != g
        if all_negative:
            # Equivalent to (groups != g) | ((groups == g) & (smiles != s)).
            allowed |= smiles != s
        true_rank = [r for r in rank if allowed[r]]
        if is_generated:
            true_rank = [r for r in true_rank if r >= half]
        res.append([smiles[r] == s for r in true_rank[:topk]])

    # Pad with False to exactly topk columns: queries with fewer eligible
    # neighbours would otherwise give a ragged (unindexable) array.
    ranking = np.zeros((len(res), topk), dtype=bool)
    for row_idx, hits in enumerate(res):
        ranking[row_idx, :len(hits)] = hits

    queried_smiles = smiles[:len(ranking)]
    treated_mask = queried_smiles != CONTROL
    treated_ranking = ranking[treated_mask]
    treated_smiles = queried_smiles[treated_mask]

    def calc_acc(ranking, topk):
        # Column k-1 is True when any of the first k neighbours matched.
        hit_within = np.logical_or.accumulate(ranking, axis=1)
        acc_stats = []
        correct_smiles = []
        for k in range(topk):
            hits = hit_within[:, k]
            per_smiles = defaultdict(int)
            for smi in treated_smiles[hits]:
                per_smiles[smi] += 1
            correct_smiles.append(per_smiles)
            acc_stats.append(int(hits.sum()) / len(ranking) if len(ranking) else float("nan"))
        return (acc_stats, correct_smiles)

    return calc_acc(treated_ranking, topk)

In [9]:
def NS_metric_across(subject, candidate, feature_col, on="Metadata_Plate", topk=10, all_negative=False, return_smiles=False):
    """Top-k "not-same-group" matching accuracy of ``subject`` queries against a ``candidate`` pool.

    Each subject row is a query. All candidate rows are ranked by cosine
    distance in ``feature_col`` space. Candidates are then filtered by the
    CANDIDATE's own ``on`` value:

    * keep candidates whose group differs from the query's group;
    * with ``all_negative=True``, also keep same-group candidates whose
      ``Metadata_SMILES`` differs from the query's.

    The previous version built this mask from the subject's groups/SMILES but
    indexed it with candidate row numbers, so the filter was wrong whenever the
    two tables differ. Both tables now need the ``on`` column.

    A query is a hit at k if any of its first k remaining candidates has the
    query's SMILES. Queries with SMILES ``'CS(C)=O'`` (the control) are excluded.

    Parameters
    ----------
    on : str
        Grouping column, e.g. ``"Metadata_Plate"``, ``"Metadata_Batch"`` or
        ``"Metadata_Source"`` (any column present in both tables is accepted).
    return_smiles : bool
        If True, also return per-SMILES hit counts.

    Returns
    -------
    list of float, or (list of float, list of defaultdict)
        ``acc_stats[k-1]`` is the fraction of treated queries with a hit within
        the top k (NaN if there are none); with ``return_smiles`` also
        ``correct_smiles[k-1]`` mapping SMILES -> hit count within the top k.
    """
    CONTROL = 'CS(C)=O'
    subject_features = subject[feature_col].to_numpy()
    candidate_features = candidate[feature_col].to_numpy()
    nneighbor = NearestNeighbors(n_neighbors=len(candidate_features), metric="cosine")
    nneighbor.fit(candidate_features)
    subject_mat = nneighbor.kneighbors(subject_features, return_distance=False)

    subject_smi = subject["Metadata_SMILES"].to_numpy()
    subject_groups = subject[on].to_numpy()
    candidate_smi = candidate["Metadata_SMILES"].to_numpy()
    candidate_groups = candidate[on].to_numpy()

    # Rows padded with False to topk columns so short neighbour lists do not
    # produce a ragged array.
    prediction = np.zeros((len(subject_smi), topk), dtype=bool)
    for i, (s, g, rank) in tqdm(enumerate(zip(subject_smi, subject_groups, subject_mat))):
        # Mask is over CANDIDATE rows, matching the indices in `rank`.
        allowed = candidate_groups != g
        if all_negative:
            allowed |= candidate_smi != s
        true_rank = [candidate_smi[r] for r in rank if allowed[r]]
        hits = [smi == s for smi in true_rank[:topk]]
        prediction[i, :len(hits)] = hits

    treated_mask = subject_smi != CONTROL
    treated_ranking = prediction[treated_mask]
    treated_smiles = subject_smi[treated_mask]

    def calc_acc(ranking, topk, return_smiles=return_smiles):
        # Column k-1 is True when any of the first k candidates matched.
        hit_within = np.logical_or.accumulate(ranking, axis=1)
        acc_stats = []
        correct_smiles = []
        for k in range(topk):
            hits = hit_within[:, k]
            per_smiles = defaultdict(int)
            for smi in treated_smiles[hits]:
                per_smiles[smi] += 1
            correct_smiles.append(per_smiles)
            acc_stats.append(int(hits.sum()) / len(ranking) if len(ranking) else float("nan"))
        if return_smiles:
            return (acc_stats, correct_smiles)
        return acc_stats

    return calc_acc(treated_ranking, topk)

In [None]:
def generate_visualization(df_check, feature_cols, hues= None,
                           type_viz='tsne', perplexity=30, model_name="Default", n_color=12, legend=True, show_bg=False,
                          save_fig=False):
    """Embed one feature table in 2-D (t-SNE or UMAP) and draw one scatter plot per hue column.

    Points whose hue value is the string ``'bg'`` are treated as background:
    they are drawn in light grey underneath the others only when ``show_bg``
    is True, and are otherwise omitted.

    Parameters
    ----------
    df_check : pandas.DataFrame
        Input table; must contain every column named in ``hues``.
    feature_cols : str or list
        Column-name prefix (str) or explicit list of feature columns.
    hues : list of str
        Columns to colour by (required).
    type_viz : {'tsne', 'umap'}
        Embedding method.
    perplexity : int
        t-SNE perplexity, or UMAP ``n_neighbors``.
    model_name : str
        Unused (the title line is commented out); kept for call compatibility.
    n_color : int
        Number of colours drawn from the glasbey palette.
    legend : bool
        Whether to draw a legend.
    show_bg : bool
        Whether to draw ``'bg'`` points.
    save_fig : bool
        If True, save each figure to ``vis_{hue_name}.png`` at 300 dpi.

    Raises
    ------
    ValueError
        If ``hues`` is None or ``type_viz`` is unsupported.
    """
    if hues is None:
        raise ValueError("hues must list at least one column to colour points by.")
    if isinstance(feature_cols, str):
        cols = [c for c in df_check.columns if c.startswith(feature_cols)]
    else:
        cols = feature_cols
    X = df_check[cols].values
    Y_total = [df_check[hue].to_numpy() for hue in hues]

    if type_viz == 'tsne':
        X_embedded = TSNE(perplexity=perplexity).fit_transform(X)
    elif type_viz == 'umap':
        X_embedded = UMAP(n_neighbors=perplexity).fit_transform(X)
    else:
        raise ValueError(f"{type_viz} not supported, choose from tsne/umap.")

    custom_palette = sns.color_palette(cc.glasbey, n_color)
    sns.set_style(style='white')
    # seaborn's default legend mode is "auto"; None suppresses the legend.
    legend_mode = "auto" if legend else None
    for hue_name, hue in zip(hues, Y_total):
        plt.figure()
        # NOTE(review): nullness comes from the FIRST hue column for every hue,
        # so rows missing hues[0] are blanked in all plots — confirm intent.
        hue = [hue[i] if not pd.isnull(Y_total[0][i]) else np.nan for i in range(len(hue))]
        bg_index = [x for x in range(len(hue)) if hue[x] == 'bg']
        bg_set = set(bg_index)  # O(1) membership instead of scanning the list
        value_index = [x for x in range(len(hue)) if x not in bg_set]
        sns.scatterplot(x=X_embedded[value_index, 0], y=X_embedded[value_index, 1],
                        hue=[hue[i] for i in value_index], palette=custom_palette,
                        legend=legend_mode, zorder=5)
        if bg_index and show_bg:
            sns.scatterplot(x=X_embedded[bg_index, 0], y=X_embedded[bg_index, 1],
                            hue=[hue[i] for i in bg_index], color="lightgray",
                            legend=legend_mode, zorder=0)
        plt.xticks([])
        plt.yticks([])
        # Only build a legend when one was requested (avoids an empty legend box).
        if legend:
            plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=1)
        if save_fig:
            plt.savefig(f"vis_{hue_name}.png", dpi=300, bbox_inches='tight')

In [None]:
# Load the centered target2 table. low_memory=False makes pandas infer
# column dtypes from the whole file rather than chunk by chunk.
data = pd.read_csv(
    filepath_or_buffer="embeddings/target2.centered.csv",
    low_memory=False,
)

In [23]:
# NOTE(review): this writes `data` back to "embeddings/target2.centered.csv",
# the same path it is read from elsewhere in this notebook, overwriting that
# input file in place (without the DataFrame index). Confirm this is intended
# before re-running; a separate output path would be safer.
data.to_csv("embeddings/target2.centered.csv", index=False)

In [19]:
# Metadata for the treated MOA target2 test set (the file has no extension
# but is parsed as CSV).
metadata_csv = pd.read_csv(
    filepath_or_buffer="datasets/treated_moa_target2/metadata_test",
)

In [12]:
# Profile feature columns: everything that is not metadata, not an embedding
# ("micon_") column and not a file-path column (presumably CellProfiler
# features — confirm against the source table).
cp_cols = [
    col
    for col in data.columns
    if not (col.startswith(("Metadata_", "micon_")) or col.endswith("_path"))
]

In [10]:
cp_fname = "embeddings/pos_control.centered.parquet"
model_fname = "embeddings/pos_control_raw_embeddings_supcon_freeze_img_train_14000.pkl"

# Merge the 1000-d embeddings onto the positive-control table.
pos_control = read_file_embeddings(cp_fname, model_fname, f_dim=1000, feature_cols="micon_")

# Column groups are captured from the per-image table, before well averaging.
cp_cols = [
    col
    for col in pos_control.columns
    if not (col.startswith(("Metadata_", "micon_")) or col.endswith("_path"))
]
micon_cols = [col for col in pos_control.columns if col.startswith("micon_")]

pos_control = average_wells(pos_control, feature_cols="micon_")

# pos_control_processed = plate_wise_spherize_and_normailize(pos_control, plate_col="Metadata_Batch", feature_cols=cp_cols, control_only=True)




35136


35136it [05:20, 109.58it/s]


 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 68/72 [02:27<00:07,  1.85s/it]

No control samples found. Fall back to full normailization


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [02:35<00:00,  2.15s/it]
 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 68/72 [01:51<00:05,  1.48s/it]

No control samples found. Fall back to full normailization


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [01:57<00:00,  1.63s/it]


In [51]:
# Apply plate_wise_spherize_and_normailize grouped by Metadata_Source:
# first to the profile columns, then to the embedding columns of the result.
pos_control_processed_source = plate_wise_spherize_and_normailize(
    pos_control,
    plate_col="Metadata_Source",
    feature_cols=cp_cols,
    control_only=True,
)
pos_control_processed_source = plate_wise_spherize_and_normailize(
    pos_control_processed_source,
    plate_col="Metadata_Source",
    feature_cols=micon_cols,
    control_only=True,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:22<00:00,  3.69s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:14<00:00,  2.35s/it]


In [14]:
cpfname = "embeddings/target2.centered.csv"
fname = "check_embeddings/treated_moa_target2/new_micon_divide255_freeze_img_resnet101_with_generate_with_aux_18000_original.pkl"

# Per-image table with the 1000-d embeddings merged on, then averaged per well.
target2_freeze_with_all = read_file_embeddings(
    cpfname,
    fname,
    f_dim=1000,
    feature_cols="micon_",
    is_moa=False,
)
target2_freeze_with_all_avg = average_wells(
    target2_freeze_with_all,
    feature_cols="micon_",
)

62162


62162it [54:43, 18.93it/s]


In [144]:
# Column groups of the well-averaged target2 table.
cp_cols = [
    col
    for col in target2_freeze_with_all_avg.columns
    if not (col.startswith(("Metadata_", "micon_")) or col.endswith("_path"))
]
micon_cols = [col for col in target2_freeze_with_all_avg.columns if col.startswith("micon_")]

# Grouped by Metadata_Batch: profile columns first, then embedding columns.
target2_batch_processed = plate_wise_spherize_and_normailize(
    target2_freeze_with_all_avg,
    plate_col="Metadata_Batch",
    feature_cols=cp_cols,
    control_only=True,
)
target2_batch_processed = plate_wise_spherize_and_normailize(
    target2_batch_processed,
    plate_col="Metadata_Batch",
    feature_cols=micon_cols,
    control_only=True,
)

# Same two steps, grouped by Metadata_Source instead.
target2_source_processed = plate_wise_spherize_and_normailize(
    target2_freeze_with_all_avg,
    plate_col="Metadata_Source",
    feature_cols=cp_cols,
    control_only=True,
)
target2_source_processed = plate_wise_spherize_and_normailize(
    target2_source_processed,
    plate_col="Metadata_Source",
    feature_cols=micon_cols,
    control_only=True,
)
# pos_control_processed = plate_wise_spherize_and_normailize(target2_freeze_with_all, plate_col="Metadata_Batch", feature_cols=cp_cols, control_only=True)
# pos_control_processed = plate_wise_spherize_and_normailize(pos_control_processed, plate_col="Metadata_Batch", feature_cols=micon_cols, control_only=True)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [02:04<00:00,  1.23s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [01:46<00:00,  1.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:20<00:00,  2.08s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████