# Imports

In [6]:
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import numpy as np
from tqdm import tqdm

In [7]:
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from sklearn.metrics.pairwise import cosine_similarity
from joblib import Parallel, delayed
import numpy as np

import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw


def compute_enrichment_factor_at_n(hit_embedding, all_embeddings, all_labels, n_percent=1):
    """
    Calculate the enrichment factor (EF) at n% for a given 'hit' embedding.

    Every sample is ranked by cosine similarity to ``hit_embedding``. The
    top-ranked sample is then discarded (it is expected to be the query itself
    when the query is one of ``all_embeddings``), and the hit rate inside the
    top n% of the remaining ranking is divided by the hit rate of that same
    remaining pool, so that numerator and baseline describe the same samples.

    Parameters:
    - hit_embedding: numpy array, the embedding of the 'hit' sample.
    - all_embeddings: numpy array, embeddings of all samples.
    - all_labels: array-like, binary labels for all samples (1 for 'hit', 0 otherwise).
    - n_percent: float, the percentage (0-100) of the dataset to consider for EF calculation.

    Returns:
    - float, the Enrichment Factor at n%; 0.0 when the pool left after
      discarding the top-ranked sample is empty or contains no hit.
    """
    all_labels = np.asarray(all_labels)

    # Cosine similarity of the query against every sample, as a flat vector.
    similarities = cosine_similarity(hit_embedding.reshape(1, -1), all_embeddings).flatten()

    # Labels in descending-similarity order, without the first entry.
    # NOTE(review): this assumes the query ranks first; if another sample ties
    # with it (e.g. duplicated embeddings) a different row may be dropped.
    ranked_labels = all_labels[np.argsort(-similarities)][1:]

    pool_size = len(ranked_labels)
    total_hits = np.sum(ranked_labels)
    if pool_size == 0 or total_hits == 0:  # Nothing to enrich / avoid division by zero
        return 0.0

    # Size of the top-n% window, never smaller than one sample.
    n_top = max(1, int(pool_size * (n_percent / 100)))
    hits_in_top_n = np.sum(ranked_labels[:n_top])

    # Baseline uses the same pool that was ranked (query excluded).
    return (hits_in_top_n / n_top) / (total_hits / pool_size)
def evaluate_EF_df(df_test, n, n_jobs=40):
    """
    Compute the mean and maximum enrichment factor over all active compounds.

    Each row flagged ``Active`` is used in turn as the query of
    ``compute_enrichment_factor_at_n`` against every row of ``df_test``.

    Parameters:
    - df_test: DataFrame with a boolean ``Active`` column and an
      ``Embeddings_mean`` column holding one embedding vector per row.
    - n: float, percentage (0-100) of the ranking used for the EF.
    - n_jobs: int, number of joblib workers (default 40, the previous
      hard-coded value).

    Returns:
    - (mEF, maxEF): mean and maximum EF over the active queries.

    Raises:
    - ValueError: if ``df_test`` contains no active row.
    """
    all_labels = (df_test['Active'] == True).astype(int).values
    hits = df_test[df_test['Active']]
    if hits.empty:
        # np.stack would fail with an unrelated-looking message on an empty list.
        raise ValueError("df_test contains no active compound; the enrichment factor is undefined")

    # float16 halves/quarters the array size — presumably to lighten what is
    # shipped to the workers; TODO confirm the precision loss is acceptable.
    hits_embeddings = np.stack(hits['Embeddings_mean'].values).astype(np.float16)
    all_embeddings = np.stack(df_test['Embeddings_mean'].values).astype(np.float16)

    enrichment_factor = Parallel(n_jobs=n_jobs)(
        delayed(compute_enrichment_factor_at_n)(hit_embedding, all_embeddings, all_labels, n)
        for hit_embedding in tqdm(hits_embeddings)
    )

    mEF = np.mean(enrichment_factor)
    maxEF = np.max(enrichment_factor)
    return mEF, maxEF

def compute_phenotypic_similarity(df):
    """Return the pairwise cosine similarity between the rows' ``Embeddings_mean`` vectors."""
    # Stack the per-row vectors into one 2-D array, one row per DataFrame row.
    stacked = np.stack(df['Embeddings_mean'])
    return cosine_similarity(stacked)

def hierarchical_clustering_and_visualization(df, similarity_matrix, threshold=0.5):
    """
    Run average-linkage hierarchical clustering on the molecules and display the results.

    The similarity matrix is turned into a distance matrix (``1 - similarity``)
    and handed to SciPy in condensed form, so the dendrogram heights and
    ``threshold`` are expressed in that distance (0 = identical direction,
    up to 2 for cosine similarity). A dendrogram is shown, then one grid image
    of the molecules per cluster.

    Args:
        df (pd.DataFrame): molecules and metadata; must provide the
            ``Metadata_JCP2022`` and ``Metadata_InChI`` columns, in the same
            row order as ``similarity_matrix``. It is not modified.
        similarity_matrix (np.ndarray): square similarity matrix between the molecules.
        threshold (float): distance at which the dendrogram is cut into clusters.

    Returns:
        pd.DataFrame: a copy of ``df`` with a ``Cluster`` column, sorted by
        cluster id with a fresh index.
    """
    # Local import: only this function needs it.
    from scipy.spatial.distance import squareform

    # Work on a copy so the caller's frame (often a slice of a bigger one) is left untouched.
    df = df.copy()

    # Similarity -> dissimilarity, cleaned so it is a valid distance matrix:
    # symmetric, non-negative (rounding noise clipped) and with a zero diagonal.
    dissimilarity = 1 - np.asarray(similarity_matrix, dtype=float)
    dissimilarity = (dissimilarity + dissimilarity.T) / 2
    np.clip(dissimilarity, 0, None, out=dissimilarity)
    np.fill_diagonal(dissimilarity, 0)

    # linkage() interprets a 2-D array as raw observation vectors; a precomputed
    # distance matrix has to be passed in condensed (1-D) form.
    linkage_matrix = linkage(squareform(dissimilarity, checks=False), method='average')

    # Dendrogram with the cut-off drawn as a dashed red line.
    plt.figure(figsize=(10, 7))
    dendrogram(linkage_matrix, labels=df['Metadata_JCP2022'].values, leaf_rotation=90)
    plt.title("Dendrogramme de clustering hiérarchique")
    plt.xlabel("Molécules")
    plt.ylabel("Distance")
    plt.axhline(y=threshold, color='r', linestyle='--', label=f'Seuil {threshold}')
    plt.legend()
    plt.show()

    # Flat clusters obtained by cutting the tree at the threshold distance.
    df['Cluster'] = fcluster(linkage_matrix, t=threshold, criterion='distance')
    df_sorted = df.sort_values(by='Cluster').reset_index(drop=True)

    # One grid image of molecule depictions per cluster.
    for cluster_id in sorted(df['Cluster'].unique()):
        cluster_df = df[df['Cluster'] == cluster_id]
        mols = [Chem.MolFromInchi(inchi) for inchi in cluster_df['Metadata_InChI']]
        legends = list(cluster_df['Metadata_JCP2022'].astype(str))

        print(f"Cluster {cluster_id} - {len(cluster_df)} molécules")
        img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200), legends=legends)
        display(img)
    return df_sorted

def visualize_similarity_matrix(df, similarity_matrix, fontsize_sim=14):
    """
    Draw the similarity matrix as a heat map with every value printed in its cell.

    Args:
        df: pandas DataFrame providing the ``Metadata_JCP2022`` column used for the tick labels.
        similarity_matrix: 2-D numpy array of similarity values.
        fontsize_sim: font size of the numbers written inside the cells.
    """
    labels = df['Metadata_JCP2022'].values
    tick_positions = range(len(labels))

    fig, ax = plt.subplots(figsize=(10, 8))

    # Heat map; no interpolation so each cell is a crisp square.
    heatmap = ax.imshow(similarity_matrix, cmap='viridis', interpolation='none')
    ax.set_title('Cosine Similarity Matrix', fontsize=14)

    # One tick per molecule on both axes.
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels, rotation=90, fontsize=8)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels, fontsize=8)

    # No minor ticks and no grid, to avoid white lines between the cells.
    for set_minor_ticks in (ax.set_xticks, ax.set_yticks):
        set_minor_ticks([], minor=True)
    ax.grid(False)

    # Write each value (two decimals) at the centre of its cell, row by row.
    n_rows, n_cols = similarity_matrix.shape[0], similarity_matrix.shape[1]
    for i, j in np.ndindex(n_rows, n_cols):
        ax.text(
            j,
            i,
            f"{similarity_matrix[i, j]:.2f}",
            ha='center',
            va='center',
            fontsize=fontsize_sim,
            color='white',
        )

    plt.colorbar(heatmap, ax=ax)

    plt.tight_layout()
    plt.show()

In [8]:
# Alternative metadata file, kept for reference:
#df_phenom = pd.read_parquet('/home/maxime/data/jump_embeddings/metadata_dinov2_g.parquet')
df_phenom = pd.read_parquet(
    '/projects/synsight/data/jump_embeddings/wells_embeddings/openphenom/metadata_openphenom.parquet'
)

# One row per distinct (compound id, InChI) pair. reset_index() is called
# without drop=True, so the previous row labels end up in an 'index' column.
df_jump = (
    df_phenom[["Metadata_JCP2022", "Metadata_InChI"]]
    .drop_duplicates()
    .reset_index()
)

In [9]:
# Rebind df_phenom to a second parquet table (compound-level embeddings, judging
# by the file name); the frame previously stored under this name is replaced.
df_phenom = pd.read_parquet('/projects/synsight/data/openphenom/norm_2_compounds_embeddings.parquet')


In [None]:
# Shared Morgan fingerprint generator (radius 2, fpSize 2048, chirality ignored)
# used by inchi_to_fp and smiles_to_fp.
mg = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=False)

In [11]:
def inchi_to_fp(inchi):
    """Return the fingerprint of an InChI string from the shared generator ``mg``, or None if RDKit yields no molecule."""
    parsed = Chem.MolFromInchi(inchi)
    return mg.GetFingerprint(parsed) if parsed else None
 
def smiles_to_fp(smiles):
    """Return the fingerprint of a SMILES string from the shared generator ``mg``, or None if RDKit yields no molecule."""
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        # Parsing produced nothing usable.
        return None
    return mg.GetFingerprint(parsed)

def bulk_tanimoto_similarity(query_fp, list_of_fps):
    """Return the Tanimoto similarity of ``query_fp`` against every fingerprint, computed in one bulk call."""
    # Materialise any iterable (Series, generator, ...) as a plain Python list first.
    fingerprints = [*list_of_fps]
    return DataStructs.BulkTanimotoSimilarity(query_fp, fingerprints)

def compute_similarity(query_fp, list_of_fps_jump):
    """
    Compute the Tanimoto similarity between a query SMILES and a list of fingerprints.

    Despite its name, ``query_fp`` is a SMILES string: it is converted to a
    fingerprint here with ``smiles_to_fp``.

    Args:
        query_fp: SMILES string of the query molecule.
        list_of_fps_jump: iterable of precomputed fingerprints; ``None``
            entries (unparsable molecules) are skipped.

    Returns:
        The similarities returned by ``bulk_tanimoto_similarity``, one per
        non-None fingerprint, in input order.

    Raises:
        ValueError: if the query is None or cannot be parsed as SMILES.
    """
    if query_fp is None:
        raise ValueError("Invalid query")

    fingerprint = smiles_to_fp(query_fp)
    if fingerprint is None:
        # Without this check the None would reach the bulk call and fail there
        # with a far less readable error.
        raise ValueError(f"Invalid query: could not parse SMILES {query_fp!r}")

    list_of_fps = [fp for fp in list_of_fps_jump if fp is not None]  # Filter out None values

    return bulk_tanimoto_similarity(fingerprint, list_of_fps)

In [None]:
# Fingerprint every InChI (None where parsing failed), attach the result as a
# column, then drop the rows that have no fingerprint. Note that the
# list_of_fps_jump list itself still contains the None entries.
list_of_fps_jump = list(map(inchi_to_fp, tqdm(df_jump['Metadata_InChI'].to_list())))
df_jump['Fps'] = list_of_fps_jump
df_jump.dropna(subset='Fps', inplace=True)

# Import mols from lit-pcba

In [19]:
import os
import pandas as pd

def load_smi_files(base_path):
    """
    Load the actives/decoys files of every sub-folder of ``base_path``.

    For each sub-folder, ``actives_final.ism`` (columns: smiles, id_dude,
    ChEMBL_id) and ``decoys_final.ism`` (columns: smiles, id_dude) are read
    when present; both are space-separated and have no header line. Rows get
    an ``Active`` flag (True for actives, False for decoys) and are
    concatenated, actives first. Sub-folders containing neither file are left
    out of the result.

    Args:
        base_path (str): directory containing one sub-folder per dataset.

    Returns:
        dict: sub-folder name -> DataFrame with the columns ``smiles``,
        ``id_dude`` and ``Active``. Keys are inserted in sorted order.
    """
    data_dict = {}

    # os.listdir order is arbitrary; sort for a reproducible dictionary order.
    for folder in sorted(os.listdir(base_path)):
        folder_path = os.path.join(base_path, folder)

        # Only directories are datasets; skip stray files.
        if not os.path.isdir(folder_path):
            continue

        actives_file = os.path.join(folder_path, "actives_final.ism")
        inactives_file = os.path.join(folder_path, "decoys_final.ism")

        all_data = []

        # Actives: the third column is read but not kept, so both tables share the same columns.
        if os.path.exists(actives_file):
            df_actives = pd.read_csv(actives_file, sep=" ", names=["smiles", "id_dude", 'ChEMBL_id'])
            df_actives["Active"] = True
            all_data.append(df_actives[["smiles", "id_dude", "Active"]])

        # Decoys, flagged inactive.
        if os.path.exists(inactives_file):
            df_inactives = pd.read_csv(inactives_file, sep=" ", names=["smiles", "id_dude"])
            df_inactives["Active"] = False
            all_data.append(df_inactives)

        # Store the dataset only if at least one of the two files was found.
        if all_data:
            data_dict[folder] = pd.concat(all_data, ignore_index=True)

    return data_dict

# Path to the folder holding one sub-folder per dataset
base_path = "../data/DUDE"

# Load every dataset into a dict of DataFrames keyed by sub-folder name
data_dict = load_smi_files(base_path)


In [None]:
# Gather the SMILES of every dataset into one flat list. Despite the variable
# name, nothing is deduplicated here: repeated SMILES are kept.
unique_smiles = []
for key, df in tqdm(data_dict.items()):
    # extend() appends in place; `a = a + b` would copy the whole list each time.
    unique_smiles.extend(df["smiles"].tolist())

In [None]:
# Total number of collected SMILES (duplicates included).
len(unique_smiles)

In [None]:
# How many times each SMILES string occurs in the 'aa2ar' dataset.
data_dict['aa2ar']['smiles'].value_counts()

In [10]:
import os
import pandas as pd
from pathlib import Path
def load_parquets_files(base_path):
    """
    Load every ``.parquet`` file located directly in ``base_path``.

    Each DataFrame is stored under the second underscore-separated token of
    the file stem (``prefix_NAME_rest.parquet`` -> ``NAME``). A stem without
    any underscore is used whole as the key. Two files yielding the same key
    overwrite each other, the later one in sorted order winning.

    Args:
        base_path (str): directory to scan (not recursive).

    Returns:
        dict: key derived from the file name -> DataFrame read from that file.
    """
    data_dict = {}
    # os.listdir order is arbitrary; sort for a reproducible result.
    for file in sorted(os.listdir(base_path)):
        file_path = Path(file)
        if file_path.suffix != '.parquet':
            continue

        tokens = file_path.stem.split('_')
        # Fall back to the whole stem rather than failing on an IndexError.
        name = tokens[1] if len(tokens) > 1 else file_path.stem
        data_dict[name] = pd.read_parquet(Path(base_path) / file_path)

    return data_dict

# Folder containing the parquet files to load
base_path = "../scripts"

# Load every parquet file into a dict of DataFrames keyed by a token of the file name
data_dict_jump = load_parquets_files(base_path)

In [None]:
import math
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")

# One subplot per dataset of data_dict, laid out on a grid of n_cols columns.
n_keys = len(data_dict)
n_cols = 6  # adjust number of columns as needed
n_rows = max(1, math.ceil(n_keys / n_cols))  # plt.subplots needs at least one row

# squeeze=False always yields a 2-D array of Axes, so flattening is valid for
# any grid shape (including a single row or a single dataset).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, n_rows * 4), squeeze=False)
axes = axes.flatten()

# Bar chart of active vs. inactive counts for each dataset
for ax, (key, df) in zip(axes, data_dict.items()):
    active_counts = df['Active'].value_counts().reset_index()
    active_counts.columns = ['Active', 'Count']

    sns.barplot(x='Active', y='Count', data=active_counts, palette='viridis', ax=ax)
    ax.set_title(key)
    ax.set_xlabel("Active Status")
    ax.set_ylabel("Count")

    # Add text labels on top of each bar
    for index, row in active_counts.iterrows():
        ax.text(index, row['Count'], row['Count'], color='black', ha="center", va='bottom')

# Remove the grid cells that received no dataset
for ax in axes[n_keys:]:
    ax.remove()

plt.tight_layout()
plt.show()


In [113]:
similarity_threshold = 1
filtered_data = {}
# For every dataset: keep the rows whose Tanimoto similarity reaches the
# threshold, one row per SMILES, then join the embedding table on the closest
# JUMP compound id and keep only the columns used downstream.
for key, df in data_dict_jump.items():
    filtered_df = df[df['tanimoto_similarity'] >= similarity_threshold]
    filtered_df = filtered_df.drop_duplicates(subset='smiles')
    filtered_df = filtered_df.merge(
        df_phenom,
        left_on='closest_jcp',
        right_on='Metadata_JCP2022',
    )
    filtered_df = filtered_df[
        ['id_lit_pcba', 'Active', 'Metadata_JCP2022', 'Embeddings_mean', 'smiles', 'Metadata_InChI']
    ]
    filtered_data[key] = filtered_df

In [None]:
import math
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")

# One subplot per dataset of filtered_data, laid out on a grid of n_cols columns.
n_keys = len(filtered_data)
n_cols = 4  # adjust number of columns as needed
n_rows = max(1, math.ceil(n_keys / n_cols))  # plt.subplots needs at least one row

# squeeze=False always yields a 2-D array of Axes, so flattening is valid for
# any grid shape (including a single row or a single dataset).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, n_rows * 4), squeeze=False)
axes = axes.flatten()

# Bar chart of active vs. inactive counts for each dataset
for ax, (key, df) in zip(axes, filtered_data.items()):
    active_counts = df['Active'].value_counts().reset_index()
    active_counts.columns = ['Active', 'Count']

    sns.barplot(x='Active', y='Count', data=active_counts, palette='viridis', ax=ax)
    ax.set_title(key)
    ax.set_xlabel("Active Status")
    ax.set_ylabel("Count")

    # Add text labels on top of each bar
    for index, row in active_counts.iterrows():
        ax.text(index, row['Count'], row['Count'], color='black', ha="center", va='bottom')

# Remove the grid cells that received no dataset
for ax in axes[n_keys:]:
    ax.remove()

plt.tight_layout()
plt.show()


In [None]:
import math
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")

# One subplot per dataset of filtered_active_data, laid out on a grid of n_cols columns.
n_keys = len(filtered_active_data)
n_cols = 3  # adjust number of columns as needed
n_rows = max(1, math.ceil(n_keys / n_cols))  # plt.subplots needs at least one row

# squeeze=False always yields a 2-D array of Axes, so flattening is valid for
# any grid shape (including a single row or a single dataset).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, n_rows * 4), squeeze=False)
axes = axes.flatten()

# Bar chart of active vs. inactive counts for each dataset
for ax, (key, df) in zip(axes, filtered_active_data.items()):
    active_counts = df['Active'].value_counts().reset_index()
    active_counts.columns = ['Active', 'Count']

    sns.barplot(x='Active', y='Count', data=active_counts, palette='viridis', ax=ax)
    ax.set_title(key)
    ax.set_xlabel("Active Status")
    ax.set_ylabel("Count")

    # Add text labels on top of each bar
    for index, row in active_counts.iterrows():
        ax.text(index, row['Count'], row['Count'], color='black', ha="center", va='bottom')

# Remove the grid cells that received no dataset. The count must come from the
# dictionary that was plotted (filtered_active_data), not from filtered_data.
for ax in axes[n_keys:]:
    ax.remove()

plt.tight_layout()
plt.show()


In [114]:
n = 10

# Filter the dictionary:
# This comprehension creates a new dictionary (filtered_active_data) containing only the keys
# for which the sum of the 'Active' column (i.e. count of active rows) is strictly greater than n.
filtered_active_data = {
    key: df for key, df in filtered_data.items() if df['Active'].sum() > n
}


# Explore

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# evaluate_EF_df(df, percent) returns (mean EF, max EF) over the active
# compounds of df for the given top-percentage of the ranking.

# Percentage of the ranking used for the EF; also drives the plot title.
ef_percent = 5

# Calculate the enrichment factors for each DataFrame in filtered_active_data
results = []
for key, df in filtered_active_data.items():
    mean_EF, max_EF = evaluate_EF_df(df, ef_percent)
    results.append({'Key': key, 'Mean EF': mean_EF, 'Max EF': max_EF})

# Convert the list of dictionaries into a DataFrame
results_df = pd.DataFrame(results)

# Transform the DataFrame to a long format suitable for seaborn's barplot
results_long = results_df.melt(id_vars='Key', value_vars=['Mean EF', 'Max EF'],
                               var_name='EF Type', value_name='EF Value')

# Set a style for the plots
sns.set(style="whitegrid")

# Create the bar plot
plt.figure(figsize=(10, 6))
bar_plot = sns.barplot(data=results_long, x='Key', y='EF Value', hue='EF Type', palette='viridis')

# Add a red horizontal dashed line at EF = 1 (the random baseline)
plt.axhline(y=1, color='red', linestyle='--', label='Random EF = 1')

# Customize the plot
plt.title(f"Enrichment Factors ({ef_percent}% of Selected Molecules)")
plt.xlabel("Key")
plt.ylabel("Enrichment Factor")
plt.xticks(rotation=45)

# Retrieve current legend handles and labels
handles, labels = plt.gca().get_legend_handles_labels()

# If the red line's label is not already included, add it manually
if 'Random EF = 1' not in labels:
    from matplotlib.lines import Line2D
    handles.append(Line2D([0], [0], color='red', linestyle='--'))
    labels.append('Random EF = 1')

plt.legend(handles=handles, labels=labels, title='EF Type')
plt.tight_layout()
plt.show()


In [None]:
# Print the mean and max enrichment factor at 5% for every retained dataset.
for key, df in filtered_active_data.items():
    print(key)
    mEF, maxEF = evaluate_EF_df(df, 5)
    print(f"Mean Enrichment Factor (mEF): {mEF}")
    # Label fixed: this line reports the maximum, not the mean.
    print(f"Max Enrichment Factor (maxEF): {maxEF}")
    print('-'*50)

In [None]:
# Print the mean and max enrichment factor at 3% for every retained dataset.
for key, df in filtered_active_data.items():
    print(key)
    mEF, maxEF = evaluate_EF_df(df, 3)
    print(f"Mean Enrichment Factor (mEF): {mEF}")
    # Label fixed: this line reports the maximum, not the mean.
    print(f"Max Enrichment Factor (maxEF): {maxEF}")
    print('-'*50)

In [128]:
# Active rows of the 'VDR' dataset only.
df = filtered_active_data['VDR'].loc[lambda frame: frame['Active']]

In [None]:
# Pairwise cosine similarity between the embeddings of the selected rows.
similarity_matrix = compute_phenotypic_similarity(df)
# Cluster and draw; the returned frame is sorted by cluster id with a fresh index.
df = hierarchical_clustering_and_visualization(df, similarity_matrix, threshold=1.7)
# Recompute the matrix so that its rows/columns follow the cluster-sorted order.
similarity_matrix = compute_phenotypic_similarity(df)
visualize_similarity_matrix(df, similarity_matrix, fontsize_sim=6)
