In [10]:
import sys
from pathlib import Path
import os

# Repository root: assumed to be two directory levels above the notebook's working directory
ROOT_DIR = Path(os.getcwd()).resolve().parent.parent
if not (ROOT_DIR / "networks").exists():
    raise FileNotFoundError("Root directory does not contain expected structure. Please adjust ROOT_DIR.")
# Make the project packages (presumably networks, src, utils) importable
sys.path.append(str(ROOT_DIR))

In [11]:
## Standard libraries
import os
import pandas as pd
from tabulate import tabulate

## Seeds
import random
import numpy as np

## PyTorch
import torch
import torch.utils.data as data

# Custom libraries
from networks.SimpleMLPs import MLPsumV2
from src.dataloader_pickles import DataloaderEvalV5
import utils
from pytorch_metric_learning import losses, distances
from tqdm import tqdm

## UMAP libraries
import matplotlib.pyplot as plt
import umap.plot
import matplotlib as mpl
import copy
import plotly.express as px

from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns


# DataLoader worker processes (0 = load data in the main process)
NUM_WORKERS = 0
# First GPU when one is visible, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG used in this notebook (Python, PyTorch, NumPy) for reproducibility
manualSeed = 42
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(manualSeed)


In [2]:
#%%
# Load model
save_name_extension = 'model_bestval_simpleMLP_V1'
model_name = save_name_extension
print('Loading:', model_name)

# Architecture hyper-parameters; they must match those the checkpoint was trained with
input_dim = 1324
kFilters = 1/2
latent_dim = 2048
output_dim = 2048
model = MLPsumV2(input_dim=input_dim, latent_dim=latent_dim, output_dim=output_dim,
                 k=kFilters, dropout=0, cell_layers=1,
                 proj_layers=2, reduction='sum')
model.to(device)  # cuda:0 when available, otherwise a no-op on the CPU

path = r'/Users/rdijk/PycharmProjects/featureAggregation/wandb/latest-run/files'
fullpath = os.path.join(path, model_name)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa)
state = torch.load(fullpath, map_location=device)
# Full training checkpoints ('ckpt' in the name) wrap the weights in a dict
if 'ckpt' in model_name:
    state = state['model_state_dict']
model.load_state_dict(state)
model.eval()

#%% Select specific plate and compound
# Alternative selection: dataset = 'Stain4', plates = ['200922_015124-V_FS']
dataset = 'Stain2'
plates = ['BR00112197binned_FS']  # substrings matched against the plate directory paths

Loading: model_bestval_simpleMLP_V1


In [3]:
#%%
# Load all data
rootDir = fr'/Users/rdijk/PycharmProjects/featureAggregation/datasets/{dataset}'
metadata = pd.read_csv('/inputs/cpg0001_metadata/JUMP-MOA_compound_platemap_with_metadata.csv', index_col=False)
# Every directory below rootDir ([1:] drops rootDir itself), restricted to the selected plates
plateDirs = [x[0] for x in os.walk(rootDir)][1:]
plateDirs = [x for x in plateDirs if any(substr in x for substr in plates)]
if not plateDirs:
    # Fail with a clear message instead of an IndexError on plateDirs[0] below
    raise FileNotFoundError(f"No plate directory under {rootDir} matches any of {plates}")

platestring = plateDirs[0].split('_')[-2]
print('Calculating results for: ' + platestring)
metadata = utils.addDataPathsToMetadata(rootDir, metadata, plateDirs)

# Filter the data and create numerical labels
df_prep = utils.filterData(metadata, '', encode='pert_iname')
# Add all data to one DF
Total, _ = utils.train_val_split(df_prep, 1.0, sort=False)
# This label-sorted order is the order in which full_loader yields wells (shuffle=False)
Total = Total.sort_values(by='Metadata_labels')

# Select only validation compounds
#valTotal = Total.sort_values(by='Metadata_labels').iloc[288:, :].reset_index(drop=True)

valset = DataloaderEvalV5(Total)
full_loader = data.DataLoader(valset, batch_size=1, shuffle=False,  # presumably one well per item
                              drop_last=False, pin_memory=False, num_workers=NUM_WORKERS)


Calculating results for: BR00112197binned


In [4]:
# Calculate profiles: one aggregated MLP profile per well, plus its numeric label.
# Chunks are collected in lists and assembled once; calling pd.concat inside the loop is quadratic.
feature_chunks, label_chunks = [], []
with torch.no_grad():
    for points, labels in tqdm(full_loader):
        feats, _ = model(points.to(device))
        feature_chunks.append(feats.cpu().numpy())
        label_chunks.append(labels.numpy())
profile_features = np.concatenate(feature_chunks)
MLP_profiles = pd.DataFrame(profile_features,
                            columns=[f"f{x}" for x in range(profile_features.shape[1])])
MLP_profiles['Metadata_labels'] = np.concatenate(label_chunks)

100%|█████████████████████████████████████████| 384/384 [00:26<00:00, 14.57it/s]


In [5]:
# Wells treated with the two MoAs of interest. (A leading sort/slice expression whose result was
# discarded has been removed; it had no effect.)
moa_of_interest = metadata.moa.isin(['LXR agonist', 'IGF-1 inhibitor'])
numbers = list(metadata[moa_of_interest].index)
metadata[moa_of_interest]

Unnamed: 0,well_position,broad_sample,solvent,InChIKey,pert_iname,pubchem_cid,moa,smiles,pert_type,control_type,plate1,Metadata_labels
28,B05,BRD-K33818169-003-04-6,DMSO,NAXSRXHZFIBFMI-UHFFFAOYSA-N,GW-3965,447905.0,LXR agonist,OC(=O)Cc1cccc(OCCCN(CC(c2ccccc2)c2ccccc2)Cc2cc...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,33
35,B12,BRD-K33818169-003-04-6,DMSO,NAXSRXHZFIBFMI-UHFFFAOYSA-N,GW-3965,447905.0,LXR agonist,OC(=O)Cc1cccc(OCCCN(CC(c2ccccc2)c2ccccc2)Cc2cc...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,33
74,D03,Compound5,DMSO,,Compound5,,IGF-1 inhibitor,,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,22
86,D15,BRD-K23383398-001-09-9,DMSO,SGIWFELWJPNFDH-UHFFFAOYSA-N,T-0901317,447912.0,LXR agonist,OC(c1ccc(cc1)N(CC(F)(F)F)S(=O)(=O)c1ccccc1)(C(...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,59
131,F12,Compound5,DMSO,,Compound5,,IGF-1 inhibitor,,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,22
164,G21,BRD-K33818169-003-04-6,DMSO,NAXSRXHZFIBFMI-UHFFFAOYSA-N,GW-3965,447905.0,LXR agonist,OC(=O)Cc1cccc(OCCCN(CC(c2ccccc2)c2ccccc2)Cc2cc...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,33
175,H08,BRD-K24696047-001-02-3,DMSO,AECDBHGVIIRMOI-GRGXKFILSA-N,NVP-AEW541,,IGF-1 inhibitor,Nc1ncnc2n(cc(-c3cccc(OCc4ccccc4)c3)c12)[C@@H]1...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,42
212,I21,BRD-K24696047-001-02-3,DMSO,AECDBHGVIIRMOI-GRGXKFILSA-N,NVP-AEW541,,IGF-1 inhibitor,Nc1ncnc2n(cc(-c3cccc(OCc4ccccc4)c3)c12)[C@@H]1...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,42
214,I23,Compound5,DMSO,,Compound5,,IGF-1 inhibitor,,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,22
237,J22,BRD-K23383398-001-09-9,DMSO,SGIWFELWJPNFDH-UHFFFAOYSA-N,T-0901317,447912.0,LXR agonist,OC(c1ccc(cc1)N(CC(F)(F)F)S(=O)(=O)c1ccccc1)(C(...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,59


In [6]:
#%% Choose which wells to calculate saliencies for

# For plotting cell saliency in UMAP plot
themes = ['fire', 'viridis', 'darkgreen', 'darkblue', 'inferno', 'darkred' ]
point_list = [34, 37, 9]
# poor BM: sirolimus (0.31), skepinone-l (0.24),
# poor BM: purmorphamine (0.30)
# good BM: valrubicin (0.9), romidepsin (0.95)

# point_list indexes the last N_SELECTABLE rows of the label-sorted plate map.
# NOTE(review): presumably the 48 held-out compounds of a 384-well plate -- confirm.
N_PLATE_WELLS, N_SELECTABLE = 384, 48
selectable_wells = (metadata.sort_values(by='Metadata_labels').reset_index(drop=True)
                    .iloc[N_PLATE_WELLS - N_SELECTABLE:].reset_index(drop=True))
WELLS = list(selectable_wells.loc[point_list, 'well_position'])
PERTURBATIONS = list(selectable_wells.loc[point_list, 'pert_iname'])

# point_list = [numbers[1]]+[numbers[4]]+[numbers[9]]+[numbers[7]]+[88,90,102,103]
# WELLS = list(metadata.loc[point_list, 'well_position'])
# PERTURBATIONS = list(metadata.loc[point_list, 'pert_iname'])
print(WELLS, PERTURBATIONS)

In [7]:
# Row positions of the selected wells in the label-sorted metadata.
# The mask is built from the sorted frame itself. Indexing the sorted frame with a mask derived from
# the unsorted `metadata` aligns on index labels, so it returned the wells' original row numbers.
metadata_sorted = metadata.sort_values(by='Metadata_labels').reset_index(drop=True)
metadata_indices_used = [metadata_sorted.index[metadata_sorted['well_position'] == well]
                         for well in WELLS]

In [8]:
# One boolean row mask (over df_prep) and one single-well DataLoader per selected well
mask_list = [df_prep.well_position == well for well in WELLS]
dataloader_list = []
for mask in mask_list:
    df_prep_masked = df_prep.loc[mask]  # the last well's frame is displayed by the next cell
    singleWell, _ = utils.train_val_split(df_prep_masked, 1.0, sort=False)
    dataloader_list.append(
        data.DataLoader(DataloaderEvalV5(singleWell), batch_size=1, shuffle=False,
                        drop_last=False, pin_memory=False, num_workers=NUM_WORKERS)
    )

In [9]:
df_prep_masked.head(n=5)  # frame of the last well handled by the loop above

Unnamed: 0,well_position,broad_sample,solvent,InChIKey,pert_iname,pubchem_cid,moa,smiles,pert_type,control_type,plate1,Metadata_labels
295,M08,BRD-K73397362-001-06-6,DMSO,FYBHCRQFSFYWPY-UHFFFAOYSA-N,purmorphamine,5284329.0,smoothened receptor agonist,C1CCC(CC1)n1cnc2c(Nc3ccc(cc3)N3CCOCC3)nc(Oc3cc...,trt,,/Users/rdijk/PycharmProjects/featureAggregatio...,81


## Saliency: Layer-wise Relevance Propagation (LRP)

In [10]:
from abc import ABC

import torch.nn as nn
from torch import reshape

def LRP_individual(model, X, device):
    """Layer-wise Relevance Propagation (LRP) relevances for the input X.

    The model is unrolled into its non-Sequential submodules (``model.modules()`` order, root
    dropped) and re-executed layer by layer. Relevance is then propagated back from the
    highest-scoring output unit using the LRP-gamma / LRP-0 rules.

    Args:
        model: network to explain. NOTE(review): assumes ``modules()`` order equals the forward
            execution order and that calling the layers in sequence reproduces ``model(X)`` --
            confirm for MLPsumV2.
        X: input tensor; in this notebook a single-well batch of shape (1, n_cells, n_features).
            It is used as-is (not moved to ``device``) in the forward pass.
        device: device the next-layer relevance is moved to in the division step.

    Returns:
        CPU tensor R[0]: the relevance assigned to the input layer. The caller expects a
        per-cell, per-feature tensor (it sums over dim 2).
    """
    # Flatten the network into its non-Sequential modules; [1:] drops the root module itself
    layers = [module for module in model.modules() if not isinstance(module, torch.nn.Sequential)][1:]
    # Propagate the input
    L = len(layers)
    A = [X] + [X] * L # A[i] is the input to layer i; A[L] is the final output
    for layer in range(L):
        # Hard-coded: the input of layer index 3 is summed over dim 1 (the cell axis).
        # NOTE(review): this mirrors the model's sum reduction only if index 3 is where MLPsumV2 pools -- confirm
        if layer == 3:
            A[layer] = torch.sum(A[layer], 1)
        A[layer + 1] = layers[layer].forward(A[layer])
    # Start relevance at the highest output score of the first sample (index [0])
    T = A[-1].cpu().detach().numpy().tolist()[0]
    index = T.index(max(T))
    # One-hot selector of that output unit
    T = np.abs(np.array(T)) * 0
    T[index] = 1
    T = torch.FloatTensor(T)
    # Relevance list with (L + 1) entries; the top entry is the selected output score (+1e-6 keeps it non-zero)
    R = [None] * L + [(A[-1].cpu() * T).data + 1e-6]
    # Propagation procedure from the top layer towards the lower layers
    # NOTE(review): the forward sum over cells at layer 3 has no explicit backward rule; below it the
    # summed-shape relevance is broadcast against per-cell z values -- verify this redistribution.
    for layer in range(0, L)[::-1]:
        if isinstance(layers[layer], torch.nn.Conv2d) or isinstance(layers[layer], torch.nn.Conv3d) \
                or isinstance(layers[layer], torch.nn.AvgPool2d) or isinstance(layers[layer], torch.nn.Linear):

            # rho transforms the layer's weights/bias.
            # NOTE(review): the 0 < layer <= 13 range looks copied from a VGG-16 LRP tutorial, not derived from this model
            if 0 < layer <= 13:  # Gamma rule (LRP-gamma)
                rho = lambda p: p + 0.25 * p.clamp(min=0)
            else:  # Basic rule (LRP-0)
                rho = lambda p: p

            # Detach this layer's input and track gradients on it
            A[layer] = A[layer].data.requires_grad_(True)
            # Step 1: forward pass through a copy of the layer with rho-transformed parameters (+1e-9 stabiliser)
            z = newlayer(layers[layer], rho).forward(A[layer]) + 1e-9
            # Step 2: element-wise division of the upper-layer relevance by z
            s = (R[layer + 1].to(device) / z).data
            # Step 3: c = d(sum(z * s)) / dA, and the relevance is A * c
            (z * s).sum().backward()
            c = A[layer].grad
            R[layer] = (A[layer] * c).cpu().data
            # Reshape relevance of layer index 4 into (batch, 16, C/16, H, W).
            # NOTE(review): this requires a 4-D relevance (conv-net leftover?); it would fail if layer 4 were a Linear layer on 3-D input
            if layer == 4:
                R[layer] = reshape(R[layer], (R[layer].shape[0], 16, int(R[layer].shape[1] / 16), R[layer].shape[2], R[layer].shape[3]))
        else:
            # Layers without a rule (activations, etc.) pass relevance through unchanged
            R[layer] = R[layer + 1]

    # Return the relevance of the input layer
    return R[0]


def newlayer(layer, g):
    """Clone a layer and pass its parameters through the function g.

    Layers without a ``weight`` attribute (e.g. ``AvgPool2d``, which LRP routes through here) and
    layers whose ``weight``/``bias`` is ``None`` (e.g. ``Linear(bias=False)``) are cloned with those
    parameters left untouched. Previously this raised AttributeError / TypeError.

    Args:
        layer: module to clone; the original is not modified.
        g: function applied to each existing parameter tensor.

    Returns:
        A deep copy of ``layer`` whose weight and bias (when present) are ``Parameter(g(p))``.
    """
    layer = copy.deepcopy(layer)
    for name in ("weight", "bias"):
        param = getattr(layer, name, None)
        if param is not None:
            setattr(layer, name, torch.nn.Parameter(g(param)))
    return layer

In [11]:
# Per-cell LRP saliency: relevance of the first (only) batch of each well, summed over features
saliency_cells_LWR = []
for loader in dataloader_list:
    points, clabel = next(iter(loader))
    relevance = LRP_individual(model, points, device)
    LWR = relevance.sum(dim=2)
    saliency_cells_LWR.append(np.squeeze(np.array(LWR.detach())))
    assert LWR.shape[1] == points.shape[1]

## Saliency V0 - L1 norm of activations per cell

In [12]:
# Saliency V0: L1 norm of each cell's activations (the model's second output, presumably per-cell)
saliency_cells_V0 = []
with torch.no_grad():  # forward pass only; no autograd graph is needed
    for loader in dataloader_list:
        points, clabel = next(iter(loader))
        print(points.shape)
        print(clabel)

        _, profile = model(points)

        saliency_V0 = profile.abs().sum(dim=2)
        saliency_cells_V0.append(np.squeeze(saliency_V0.numpy()))
        # One saliency value per input cell. (The previous check compared the tensor with itself.)
        assert saliency_V0.shape[1] == points.shape[1]

torch.Size([1, 2138, 1324])
tensor([87], dtype=torch.int16)
torch.Size([1, 2872, 1324])
tensor([88], dtype=torch.int16)
torch.Size([1, 3147, 1324])
tensor([81], dtype=torch.int16)


## Saliency V1 - Gradient based saliency (L1 norm per cell)

In [13]:
commonfeats_df = pd.read_csv('/Users/rdijk/Documents/Data/RawData/CommonFeatureNames.csv', index_col=False)
Y = 5  # number of highest-gradient features reported per cell
loss_func = losses.SupConLoss(distance=distances.CosineSimilarity())
#%% Saliency_V1 Calculate gradient based saliencies - L1 norm of gradients per cell
# MLP_profiles rows follow the row order of `Total` (sorted by Metadata_labels and fed to
# full_loader with shuffle=False; assumes DataloaderEvalV5 yields rows in frame order).
# The masks in mask_list are boolean Series indexed like df_prep. Indexing MLP_profiles with
# such a Series aligns on index labels, not positions, so it removed an unrelated well.
# Build the leave-this-well-out selection positionally from Total's well positions instead.
profile_wells = Total['well_position'].to_numpy()
assert len(profile_wells) == len(MLP_profiles)
saliency_cells_V1 = []
saliency_feats_V1 = []
topY_feature_list_of_list = []
for well, loader in zip(WELLS, dataloader_list):
    points, clabel = next(iter(loader))
    print(points.shape)
    print(clabel)

    keep = profile_wells != well  # remove the profile of the well we are analysing
    assert (~keep).sum() == 1, f"expected exactly one profile for well {well}"
    MLP_profiles_masked = MLP_profiles[keep]
    points.requires_grad_()
    profile, _ = model(points)

    L = loss_func(torch.cat([torch.tensor(MLP_profiles_masked.iloc[:, :-1].values), profile]),
                  torch.cat([torch.tensor(MLP_profiles_masked.iloc[:, -1].values), clabel]))

    L.backward()  # fills points.grad (also read by the heatmap cell below); returns None
    grad_abs = points.grad.detach().abs()
    saliency_V1 = torch.sum(grad_abs, dim=2)        # one value per cell
    saliency_features = torch.sum(grad_abs, dim=1)  # one value per input feature

    # Top Y features per cell. [0] drops only the batch axis; torch.squeeze would also drop the
    # cell axis of a single-cell well and break axis=1 below.
    top = np.argpartition(grad_abs[0].numpy(), -Y, axis=1)[:, -Y:]
    print(top.shape)
    topY_feature_list = [','.join(commonfeats_df.iloc[top[k], 0]) for k in range(points.shape[1])]
    topY_feature_list_of_list.append(topY_feature_list)

    saliency_cells_V1.append(np.squeeze(np.array(saliency_V1)))
    saliency_feats_V1.append(np.squeeze(np.array(saliency_features)))
    assert saliency_V1.shape[1] == points.shape[1]
    assert saliency_features.shape[1] == points.shape[2]

torch.Size([1, 2138, 1324])
tensor([87], dtype=torch.int16)
torch.Size([2138, 5])
torch.Size([1, 2872, 1324])
tensor([88], dtype=torch.int16)
torch.Size([2872, 5])
torch.Size([1, 3147, 1324])
tensor([81], dtype=torch.int16)
torch.Size([3147, 5])


In [14]:
# plot the heatmap
print(points.grad.data.abs()[0,...].shape)
FN = pd.read_csv('/Users/rdijk/Documents/Data/RawData/CommonFeatureNames.csv', index_col=False)

# True: label rows by compartment (Cells/Cytoplasm/Nuclei);
# False: label rows by measurement type (Intensity/RadialDistribution/Correlation, ...)
USE_COMPARTMENT_LABELS = False
if USE_COMPARTMENT_LABELS:
    feature_names = [x.split('.')[0] for x in FN.FeatureNames]
else:
    feature_names = [x.split('.')[1].split('_')[1] for x in FN.FeatureNames]
fn = pd.DataFrame({'featurenames': feature_names})
feature_types = fn.featurenames.unique()
# One distinct colour per feature type. The old colour string "rbgcmcyckw" repeated 'c' (three
# types shared cyan) and left any type beyond the 10th without a colour.
palette = sns.color_palette('tab10' if len(feature_types) <= 10 else 'husl', len(feature_types))
lut = dict(zip(feature_types, palette))
row_colors = fn.featurenames.map(lut)
feature_types


torch.Size([3147, 1324])


array(['Intensity', 'RadialDistribution', 'Correlation', 'Parent',
       'Granularity', 'Location', 'AreaShape', 'Neighbors', 'Number',
       'Children'], dtype=object)

In [15]:
# sns.clustermap(points.grad.data.abs()[0,...].t(), 
#                row_colors=row_colors.to_numpy())
# from matplotlib.patches import Patch

# handles = [Patch(facecolor=lut[name]) for name in lut]
# plt.legend(handles, lut, title='Feature type',
#            bbox_to_anchor=(1, 1), bbox_transform=plt.gcf().transFigure, loc='upper right')

In [16]:
# from scipy.spatial import distance
# from scipy.cluster import hierarchy
# row_linkage = hierarchy.linkage(
#     distance.pdist(points.grad.data.abs()[0,...].t()), method='average', optimal_ordering=True)
# cluster_numbers = hierarchy.fcluster(row_linkage, 2, criterion='maxclust')
# a = pd.DataFrame({'feature': FN.values[:,0], 'cluster': cluster_numbers})
# a = a.groupby('cluster')
# for key, item in a:
#     print(a.get_group(key).to_markdown(), "\n\n")


In [17]:
# Trace the features saliency back to their names
saliency_feats_dfs = []
# Return top X features; X = 0 (or X >= number of features) keeps every feature and only prints the shape
X = 0
for idx in range(len(saliency_feats_V1)):
    feature_saliency = saliency_feats_V1[idx]
    if 0 < X < len(feature_saliency):
        top_idx = np.argpartition(feature_saliency, -X)[-X:]
    else:
        # Explicit "all features" instead of relying on [-0:] slicing, and no argpartition
        # ValueError when X exceeds the number of features
        top_idx = np.arange(len(feature_saliency))
    newDF = pd.DataFrame({'FeatureNames': commonfeats_df.iloc[top_idx, 0],
                          'Saliency': feature_saliency[top_idx]})
    newDF = newDF.sort_values(by='Saliency', ascending=False)
    saliency_feats_dfs.append(newDF)
    if 0 < X < len(feature_saliency):
        print(tabulate(newDF, headers='keys', tablefmt='github', showindex=False))
    else:
        print(newDF.shape)

(1324, 2)
(1324, 2)
(1324, 2)


## Saliency V2 - Distance based saliency (aggregated representations in loss space)

In [18]:
# Calculate distance based saliencies - distance of single cell representations to aggregated representation
all_single_cell_profiles = []
all_final_profiles = []
saliency_cells_V2 = []
with torch.no_grad():  # inference only: avoid building an autograd graph for every cell
    for loader in dataloader_list:
        points, clabel = next(iter(loader))
        print(points.shape)
        print(clabel)

        final_profile, _ = model(points)
        # Profile of each cell on its own, i.e. a "well" that contains only that cell
        sc_profiles = [model(points[:, [idx], :])[0] for idx in tqdm(range(points.shape[1]))]
        # Concatenate once instead of growing a tensor inside the loop (quadratic copying)
        single_cell_profiles = torch.cat(sc_profiles).numpy()
        final_profile = final_profile.numpy()
        # Cosine similarity of every single-cell profile to the well's aggregated profile
        cos = cosine_similarity(single_cell_profiles, final_profile)[:, 0]
        saliency_cells_V2.append(cos)
        all_single_cell_profiles.append(single_cell_profiles)
        all_final_profiles.append(final_profile)
        assert len(cos) == points.shape[1]

torch.Size([1, 2138, 1324])
tensor([87], dtype=torch.int16)


100%|██████████████████████████████████████| 2138/2138 [00:04<00:00, 469.27it/s]


torch.Size([1, 2872, 1324])
tensor([88], dtype=torch.int16)


100%|██████████████████████████████████████| 2872/2872 [00:09<00:00, 296.24it/s]


torch.Size([1, 3147, 1324])
tensor([81], dtype=torch.int16)


100%|██████████████████████████████████████| 3147/3147 [00:16<00:00, 191.72it/s]


## Saliency V3 - Leave-one-cell-out saliency (loss influence per cell)

In [19]:
# all_losses = []
# for loader, mask in zip(dataloader_list, mask_list):
#     points, clabel = next(iter(loader))
#     print(points.shape)
#     print(clabel)

#     # Leave one cell out
#     loco_losses = []
#     MLP_profiles_masked = MLP_profiles[~mask] # remove well profile that we are analysing
#     for idx in tqdm(range(points.shape[1])):
#         masked_points = points[:, torch.arange(points.size(1))!=idx, :]
#         masked_profile, _ = model(masked_points)

#         L = loss_func(torch.cat([torch.tensor(MLP_profiles_masked.iloc[:, :-1].values), masked_profile]), 
#                       torch.cat([torch.tensor(MLP_profiles_masked.iloc[:, -1].values), clabel]))
#         loco_losses.append(L.item())
#     all_losses.append(loco_losses)
#     assert len(loco_losses) == points.shape[1]

In [20]:
# saliency_cells_V3 = []
# for LOSSES in all_losses:
#     # Convery to numpy array
#     loco_losses_array = np.array(LOSSES)
#     # Calculate saliency for most noisy cells, that, when removed, decrease the loss the most.
#     # Also normalize the saliency values
#     saliency_min = np.min(loco_losses_array) / loco_losses_array

#     # lowest loss corresponds to noisiest cell, thus invert for most salient cell
#     saliency_cells_V3.append(abs(1-saliency_min))

In [21]:
# saliency_cells_V3[0]

# UMAP visualizations

In [22]:
# Stack the cells of all selected wells into one (total_cells, n_features) array for a joint UMAP
all_points_list, all_cellcounts, all_labels, all_mean_profiles = [], [], [], []
for loader in dataloader_list:
    points, clabel = next(iter(loader))
    well_cells = points[0, ...]  # drop the batch axis
    all_points_list.append(well_cells)
    all_cellcounts.append(well_cells.shape[0])
    all_labels.append(clabel)
    all_mean_profiles.append(well_cells.mean(0).detach().numpy())  # per-well mean feature vector
all_points = np.concatenate(all_points_list)
all_labels = np.concatenate(all_labels)
# Cumulative counts: the cells of well i occupy rows all_cellcounts[i-1]:all_cellcounts[i]
all_cellcounts = np.cumsum(all_cellcounts)
print(all_points.shape)
print(all_labels)

(8157, 1324)
[87 88 81]


In [None]:
# Configure UMAP on the raw single-cell features of all selected wells
umap_kwargs = dict(random_state=42, metric='euclidean', n_epochs=700, repulsion_strength=5,
                   local_connectivity=3, n_neighbors=100)
reducer = umap.UMAP(**umap_kwargs)
embedding = reducer.fit(all_points)  # fit returns the fitted reducer itself
embedding.embedding_.shape

In [None]:
# Plot all saliency types next to each other:
# row 0 = gradient saliency (V1), row 1 = distance saliency (V2). Row 2 is left for V3, whose
# computation is currently commented out.
fig, ax = plt.subplots(3, 1, dpi=300)
for row, saliency_cells in enumerate([saliency_cells_V1, saliency_cells_V2]):
    for i, theme in zip(range(3), themes):
        # all_cellcounts is cumulative, so well i owns rows [start, all_cellcounts[i])
        start = all_cellcounts[i - 1] if i > 0 else 0
        c_embedding = copy.deepcopy(embedding)
        c_embedding.embedding_ = embedding.embedding_[start:all_cellcounts[i], ...]
        values = saliency_cells[i]
        values = (values - np.min(values)) / (np.max(values) - np.min(values))  # normalize
        umap.plot.points(c_embedding, values=values, theme=theme, ax=ax[row])

plt.show()

In [None]:
# V1 (row 0) and V2 (row 1) saliency panels for the second selected well only (index 1)
fig, ax = plt.subplots(3, 1, dpi=300)
i = 1
theme = themes[i]
for row, saliency_cells in enumerate([saliency_cells_V1, saliency_cells_V2]):
    c_embedding = copy.deepcopy(embedding)
    # all_cellcounts is cumulative: well 1 lies between the first and second counts
    c_embedding.embedding_ = embedding.embedding_[all_cellcounts[i - 1]:all_cellcounts[i], ...]
    values = saliency_cells[i]
    values = (values - np.min(values)) / (np.max(values) - np.min(values))  # normalize
    umap.plot.points(c_embedding, values=values, theme=theme, ax=ax[row])

plt.show()

In [None]:
# V1 (row 0) and V2 (row 1) saliency panels for the first selected well only (index 0)
fig, ax = plt.subplots(3, 1, dpi=300)
i = 0
theme = themes[i]
for row, saliency_cells in enumerate([saliency_cells_V1, saliency_cells_V2]):
    c_embedding = copy.deepcopy(embedding)
    # well 0 occupies the first all_cellcounts[0] rows of the joint embedding
    c_embedding.embedding_ = embedding.embedding_[:all_cellcounts[i], ...]
    values = saliency_cells[i]
    values = (values - np.min(values)) / (np.max(values) - np.min(values))  # normalize
    umap.plot.points(c_embedding, values=values, theme=theme, ax=ax[row])

plt.show()

In [None]:
# V1 (row 0) and V2 (row 1) saliency panels for the third selected well only (index 2)
fig, ax = plt.subplots(3, 1, dpi=300)
i = 2
theme = themes[i]
for row, saliency_cells in enumerate([saliency_cells_V1, saliency_cells_V2]):
    c_embedding = copy.deepcopy(embedding)
    # all_cellcounts is cumulative: well 2 lies between the second and third counts
    c_embedding.embedding_ = embedding.embedding_[all_cellcounts[i - 1]:all_cellcounts[i], ...]
    values = saliency_cells[i]
    values = (values - np.min(values)) / (np.max(values) - np.min(values))  # normalize
    umap.plot.points(c_embedding, values=values, theme=theme, ax=ax[row])

plt.show()

In [None]:
saliency_type = 'V1'  # one of 'LWR', 'V0', 'V1', 'V2', 'V3'

# Resolve the selected saliency once. An unrecognised type previously fell through and silently
# plotted whatever `values` was left over from an earlier cell.
if saliency_type == 'LWR':
    selected_saliency = saliency_cells_LWR
elif saliency_type == 'V0':
    selected_saliency = saliency_cells_V0
elif saliency_type == 'V1':
    selected_saliency = saliency_cells_V1
elif saliency_type == 'V2':
    selected_saliency = saliency_cells_V2
elif saliency_type == 'V3':
    selected_saliency = saliency_cells_V3
else:
    raise ValueError(f"Unknown saliency_type: {saliency_type!r}")

fig, ax = plt.subplots(1, 1, dpi=300)
for i, theme in zip(range(3), themes):
    c_embedding = copy.deepcopy(embedding)
    if i==0:
        c_embedding.embedding_ = embedding.embedding_[:all_cellcounts[i], ...]
    else:
        c_embedding.embedding_ = embedding.embedding_[all_cellcounts[i-1]:all_cellcounts[i], ...]

    values = selected_saliency[i]
    values = (values-np.min(values)) / (np.max(values)-np.min(values)) # normalize
    umap.plot.points(c_embedding, values=values, theme=theme, ax=ax)

plt.show()

In [None]:
# PLOT ALL SALIENCIES TOGETHER
# Combined saliency V4 = sum of the per-well min-max normalised V1, V2 and, when computed, V3.
# The V3 cells are commented out, so referencing saliency_cells_V3 unconditionally raised NameError.
combined_sources = [saliency_cells_V1, saliency_cells_V2]
if 'saliency_cells_V3' in globals():
    combined_sources.append(saliency_cells_V3)

saliency_cells_V4 = []

fig, ax = plt.subplots(1, 1, dpi=300)
for i, theme in zip(range(3), themes):
    c_embedding = copy.deepcopy(embedding)
    if i==0:
        c_embedding.embedding_ = embedding.embedding_[:all_cellcounts[i], ...]
    else:
        c_embedding.embedding_ = embedding.embedding_[all_cellcounts[i-1]:all_cellcounts[i], ...]

    values = sum((s[i] - np.min(s[i])) / (np.max(s[i]) - np.min(s[i])) for s in combined_sources)
    saliency_cells_V4.append(values)

    umap.plot.points(c_embedding, values=values, theme=theme, ax=ax)

plt.show()

In [None]:
np.shape(embedding.embedding_)  # (total number of cells, 2)

In [None]:
# Compound name for every cell of the joint embedding, in embedding row order.
# Per-well counts are recovered from the cumulative all_cellcounts; this works for any number of
# wells (it was hard-coded to 3) and avoids re-building the list on every iteration.
cellnrs = np.diff(all_cellcounts, prepend=0)
labels_new = [pert for pert, n_cells in zip(PERTURBATIONS, cellnrs) for _ in range(n_cells)]
assert len(labels_new) == embedding.embedding_.shape[0]

hover_data = pd.DataFrame({'compound': labels_new})

fig = px.scatter(
    embedding.embedding_, x=0, y=1,
    color=hover_data.compound, labels={'color': 'compound'}
)

fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
fig.update_layout(yaxis_range=[-14,8], xaxis_range=[3, 24])

fig.show()

### Saliency V0

In [None]:
#%% Activation based saliency (V0)
# Keep the cells whose per-well min-max normalised saliency exceeds 0.3
per_well_subsets = []
for i in range(3):
    v0 = saliency_cells_V0[i]
    v0_norm = (v0 - np.min(v0)) / (np.max(v0) - np.min(v0))
    print(v0_norm[np.argpartition(v0_norm, -2)][-2:])  # the two largest normalised values
    per_well_subsets.append(v0_norm > 0.3)
subset_list = np.concatenate(per_well_subsets)

fig = px.scatter(
    embedding.embedding_[subset_list], x=0, y=1,
    color=hover_data.compound[subset_list], labels={'color': 'compound'},
)
fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
fig.update_layout(xaxis_range=[3, 24], yaxis_range=[-14, 8])
fig.show()

### Saliency V1

In [None]:
#%% Gradient based saliency (V1)
# Keep the cells whose per-well min-max normalised saliency exceeds 0.8
subset_list = np.concatenate([
    (sal - np.min(sal)) / (np.max(sal) - np.min(sal)) > 0.8
    for sal in (saliency_cells_V1[i] for i in range(3))
])

fig = px.scatter(
    embedding.embedding_[subset_list], x=0, y=1,
    color=hover_data.compound[subset_list], labels={'color': 'compound'},
)
fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
fig.update_layout(xaxis_range=[3, 24], yaxis_range=[-14, 8])
fig.show()

### Saliency V2

In [None]:
#%% Distance based saliency (V2) - single cell profile vs. aggregated profile
# Keep the cells whose per-well min-max normalised saliency exceeds 0.8
subset_list = np.concatenate([
    (sal - np.min(sal)) / (np.max(sal) - np.min(sal)) > 0.8
    for sal in (saliency_cells_V2[i] for i in range(3))
])

fig = px.scatter(
    embedding.embedding_[subset_list], x=0, y=1,
    color=hover_data.compound[subset_list], labels={'color': 'compound'},
)
fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
fig.update_layout(xaxis_range=[3, 24], yaxis_range=[-14, 8])
fig.show()

### Saliency V3

In [None]:
#%% Loss based saliency - leave one cell out (V3)
# The V3 computation cells are commented out; skip this plot with a message instead of raising NameError
if 'saliency_cells_V3' not in globals():
    print('saliency_cells_V3 is not defined - run the Saliency V3 cells first; skipping this plot.')
else:
    # Keep the cells whose per-well min-max normalised saliency exceeds 0.8
    subset_list = []
    for i in range(3):
        values = saliency_cells_V3[i]
        values = (values-np.min(values)) / (np.max(values)-np.min(values))
        subset = values>0.8
        subset_list.append(subset)
    subset_list = np.concatenate(subset_list)

    fig = px.scatter(
        embedding.embedding_[subset_list], x=0, y=1,
        color=hover_data.compound[subset_list], labels={'color': 'compound'}
    )

    fig.update_xaxes(fixedrange=True)
    fig.update_yaxes(fixedrange=True)
    fig.update_layout(yaxis_range=[-14,8], xaxis_range=[3, 24])

    fig.show()

### Saliency V4

In [None]:
#%% Loss based saliency
# Per perturbation, min-max normalise the V4 saliency of every cell and keep only the
# cells above the cutoff; those cells are then drawn in the existing UMAP embedding.
# NOTE(review): a later cell reassigns `saliency_cells_V4` (V0 + V1), so what this plot
# shows depends on execution order.
saliency_cutoff = 0.8
subset_list = []
for i in range(len(saliency_cells_V4)):
    values = np.asarray(saliency_cells_V4[i])
    value_range = np.ptp(values)
    # A constant saliency vector (max == min) would otherwise divide by zero and yield NaN.
    if value_range > 0:
        values = (values - values.min()) / value_range
    else:
        values = np.zeros_like(values)
    subset_list.append(values > saliency_cutoff)
subset_list = np.concatenate(subset_list)

fig = px.scatter(
    embedding.embedding_[subset_list], x=0, y=1,
    color=hover_data.compound[subset_list], labels={'color': 'compound'}
)

fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)
fig.update_layout(yaxis_range=[-14,8], xaxis_range=[3, 24])

fig.show()

In [None]:
# Show the comma-separated first entry of topY_feature_list with one feature per line.
print(topY_feature_list[0].replace(',', '\n'))

## Visualize profiles versus single cells

In [87]:
# Configure UMAP: single-cell points and aggregated profiles are embedded jointly
# (cells first, profiles after) so both can be drawn in one 2-D space.
mlp_data = True  # True: MLP single-cell + final profiles; False: all_points + mean profiles
new_reducer = umap.UMAP(random_state=42, metric='cosine', n_neighbors=375)
umap_cell_points = np.concatenate(all_single_cell_profiles) if mlp_data else all_points
umap_profile_points = np.concatenate(all_final_profiles if mlp_data else all_mean_profiles)
new_embedding = new_reducer.fit(np.concatenate([umap_cell_points, umap_profile_points]))


In [24]:
def _min_max_normalise(values):
    """Rescale `values` to [0, 1]; a constant vector maps to zeros instead of NaN."""
    values = np.asarray(values)
    value_range = np.ptp(values)
    if value_range == 0:
        return np.zeros_like(values, dtype=float)
    return (values - values.min()) / value_range

# Normalise V0 and V1 per perturbation so they contribute on the same scale, then define
# V4 as their element-wise sum (values in [0, 2]).
# NOTE(review): this replaces any previously computed `saliency_cells_V4`, so earlier
# cells that read V4 and the saved 'SaliencyV4' column depend on whether this ran.
saliency_cells_V0 = [_min_max_normalise(values) for values in saliency_cells_V0]
saliency_cells_V1 = [_min_max_normalise(values) for values in saliency_cells_V1]
saliency_cells_V4 = [v0 + v1 for v0, v1 in zip(saliency_cells_V0, saliency_cells_V1)]

In [91]:
# Plot single cells (small markers) and aggregated profiles (large markers) in the joint
# UMAP embedding. With `threshold`, only cells whose normalised V4 saliency exceeds the
# cutoff are shown; profile points are always kept.
threshold = True
saliency_cutoff = 0.8
n_perturbations = len(saliency_cells_V1)
# The mask, labels and sizes below assume the row layout "all cells of each perturbation,
# then one profile per perturbation"; mismatched lengths would break the boolean indexing.
assert n_perturbations == len(PERTURBATIONS), 'expected one saliency vector per perturbation'

if threshold:
    subset_list = []
    for i in range(n_perturbations):
        values = np.asarray(saliency_cells_V4[i])
        value_range = np.ptp(values)
        # A constant saliency vector (max == min) would otherwise divide by zero and yield NaN.
        if value_range > 0:
            values = (values - values.min()) / value_range
        else:
            values = np.zeros_like(values)
        subset_list.append(values > saliency_cutoff)
    subset_list.append(np.ones(n_perturbations, dtype=bool))  # always keep profile points
    subset_list = np.concatenate(subset_list)

# np.diff turns the cumulative all_cellcounts into per-perturbation cell counts.
cellnrs = [all_cellcounts[0]] + list(np.diff(all_cellcounts))
labels_new = []
scatter_sizes = []
for k in range(n_perturbations):
    labels_new.extend([PERTURBATIONS[k]] * cellnrs[k])
    scatter_sizes.extend([0.5] * cellnrs[k])  # small markers for single cells
labels_new.extend(PERTURBATIONS)  # profiles are labelled with the bare perturbation name
scatter_sizes.extend([7] * len(PERTURBATIONS))  # large markers for profiles
new_hover_data = pd.DataFrame({'labels': labels_new})

if threshold:
    fig = px.scatter(
        new_embedding.embedding_[subset_list], x=0, y=1,
        color=new_hover_data.labels[subset_list], labels={'color': 'labels'},
        size=np.array(scatter_sizes)[subset_list]
    )
else:
    fig = px.scatter(
        new_embedding.embedding_, x=0, y=1,
        color=new_hover_data.labels, labels={'color': 'labels'},
        size=scatter_sizes
    )

fig.update_xaxes(fixedrange=True)
fig.update_yaxes(fixedrange=True)

fig.show()
# print('Euclidean distance between skepinone-l and purmorphamine:', np.linalg.norm(all_mean_profiles[1]-all_mean_profiles[2]))
# print('Euclidean distance between skepinone-l and sirolimus:' ,np.linalg.norm(all_mean_profiles[0]-all_mean_profiles[2]))
# print('Euclidean distance between sirolimus and purmorphamine:' ,np.linalg.norm(all_mean_profiles[1]-all_mean_profiles[0]))

## Save the saliencies 

In [25]:
# Write, per (well, perturbation), one CSV with the per-cell saliencies and top features,
# and one CSV with the per-feature saliencies from saliency_feats_dfs.
# DataFrame.to_csv does not create missing parent directories, so ensure it exists first.
os.makedirs('Saliencies', exist_ok=True)
plate_tag = plates[0][:-3]
for idx, (W, P) in enumerate(zip(WELLS, PERTURBATIONS)):
    # 'SaliencyV3' is intentionally not written here.
    df = pd.DataFrame({'SaliencyV0': saliency_cells_V0[idx],
                       'SaliencyV1': saliency_cells_V1[idx],
                       'SaliencyV2': saliency_cells_V2[idx],
                       'SaliencyV4': saliency_cells_V4[idx],
                       'TopYfeats': topY_feature_list_of_list[idx]})
    df.to_csv(f'Saliencies/saliencies_{plate_tag}_{W}_{P}.csv')

    saliency_feats_dfs[idx].to_csv(f'Saliencies/last_model_feature_saliencies_{plate_tag}_{W}_{P}.csv')
    