# Imports 

In [None]:
from pathlib import Path
from phenoseeker import EmbeddingManager
import matplotlib.pyplot as plt
import umap
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import pandas as pd


In [None]:
# Inputs of the notebook: the embedding array (.npy) and the metadata table
# (.parquet). Both are absolute paths specific to one machine.
embeddings_path = Path('/home/maxime/data/jump_embeddings/embeddings_dinov2_g.npy')
metadata_path = Path('/home/maxime/data/jump_embeddings/metadata_dinov2_g.parquet')

In [5]:
# Hand-written selection of plate identifiers, used further down as the
# ``Metadata_Plate`` filter value. Entries prefixed with ``#`` are disabled.
# NOTE(review): the next cell rebinds ``selected_plates`` from a JSON file, so
# this list is only effective when that cell is skipped.
selected_plates = [#'UL000585',
 #'UL001773',
 #'GR00004405',
 #'UL000097',
 #'GR00003344',
 #'GR00003285',
 '1086292853',
 'EC000134',
 'B40703bW',
 'EC000065',
 'BR5873d3W',
 'J12424d',
 '1086291931',
 'EC000071',
 '110000296323',
 'AEOJUM504',
 '110000295571',
 '110000297103',
 #'1086293027',
 'A1170530',
 'Dest210726-161624',
 'Dest210809-135330',
 'A1170468',
 'A1170419',
 'APTJUM208',
 'Dest210823-180708',
 'AEOJUM902']

In [5]:
# Read the plate selection from the JSON file and keep its "Metadata_Plate"
# column as a plain Python list.
selected_plates = (
    pd.read_json(
        "/home/maxime/synrepos/phenoseeker/scripts/balanced_plates_51_lab.json"
    )["Metadata_Plate"]
    .to_list()
)

In [6]:
def plot_umap(embeddings, df, color_by_source=True, shape_by_JCP_id=True):
    """
    Plot a UMAP projection of the embeddings with optional coloring and shaping.

    The embeddings are reduced to two dimensions with ``umap.UMAP`` (fixed
    ``random_state=42``) and drawn as a scatter plot whose legends sit to the
    right of the axes. The caller's ``df`` is never modified: the UMAP
    coordinates are stored on an internal copy.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Array of shape (n_samples, n_features) containing the embeddings.
        Row ``i`` must describe the same sample as row ``i`` of ``df``.
    df : pandas.DataFrame
        DataFrame containing at least the following columns:
          - If shape_by_JCP_id is True: 'Metadata_JCP2022'
          - If color_by_source is True: 'Metadata_Source'
    color_by_source : bool, optional (default True)
        If True, each point is colored according to its value in
        'Metadata_Source', using the 'tab10' palette. Colors are reused
        cyclically when there are more sources than palette entries.
    shape_by_JCP_id : bool, optional (default True)
        If True, each point is drawn with a marker shape according to its
        value in 'Metadata_JCP2022'. Eight shapes are available and are reused
        cyclically when there are more groups.

    Returns:
    --------
    None
        The figure is displayed with ``plt.show()``.

    Raises:
    -------
    KeyError
        If a metadata column required by the enabled options is missing.
    """
    import matplotlib.pyplot as plt
    import umap
    import matplotlib.patches as mpatches
    from matplotlib.lines import Line2D

    # Work on a private copy limited to the columns this plot needs, so the
    # UMAP coordinates never leak into the caller's DataFrame. Selecting the
    # columns first also fails fast, before the costly UMAP fit.
    needed_columns = []
    if shape_by_JCP_id:
        needed_columns.append('Metadata_JCP2022')
    if color_by_source:
        needed_columns.append('Metadata_Source')
    plot_df = df[needed_columns].copy()

    # --- Dimensionality reduction with UMAP ---
    reducer = umap.UMAP(n_components=2, random_state=42)
    embedding_2d = reducer.fit_transform(embeddings)
    plot_df['UMAP1'] = embedding_2d[:, 0]
    plot_df['UMAP2'] = embedding_2d[:, 1]

    fig, ax = plt.subplots(figsize=(12, 10))

    # One marker per Metadata_JCP2022 value (cycled when groups outnumber markers).
    if shape_by_JCP_id:
        jcp_groups = sorted(plot_df['Metadata_JCP2022'].unique())
        markers = ['o', 's', '^', 'v', 'D', 'P', 'X', '*']
        marker_dict = {
            group: markers[i % len(markers)]
            for i, group in enumerate(jcp_groups)
        }

    # One color per Metadata_Source value. The index is wrapped explicitly:
    # calling the colormap with an out-of-range integer would otherwise return
    # the last color for every extra source.
    if color_by_source:
        source_groups = sorted(plot_df['Metadata_Source'].unique())
        cmap_source = plt.get_cmap('tab10')
        source_colors = {
            source: cmap_source(i % cmap_source.N)
            for i, source in enumerate(source_groups)
        }

    # --- Draw the points ---
    point_style = dict(edgecolor='k', s=100, alpha=0.8)
    if shape_by_JCP_id:
        # One scatter call per group, because a call accepts a single marker.
        for group in jcp_groups:
            subset = plot_df[plot_df['Metadata_JCP2022'] == group]
            if color_by_source:
                color_kwargs = {
                    'c': subset['Metadata_Source'].map(source_colors).tolist()
                }
            else:
                color_kwargs = {'color': "blue"}
            ax.scatter(
                subset['UMAP1'],
                subset['UMAP2'],
                marker=marker_dict[group],
                label=str(group),
                **color_kwargs,
                **point_style,
            )
    elif color_by_source:
        for source in source_groups:
            subset = plot_df[plot_df['Metadata_Source'] == source]
            ax.scatter(
                subset['UMAP1'],
                subset['UMAP2'],
                marker="o",
                color=source_colors[source],
                label=str(source),
                **point_style,
            )
    else:
        ax.scatter(
            plot_df['UMAP1'],
            plot_df['UMAP2'],
            marker="o",
            color="blue",
            **point_style,
        )

    # --- Axes cosmetics ---
    ax.set_xlabel("UMAP1", fontsize=16)
    ax.set_ylabel("UMAP2", fontsize=16)
    ax.set_title("UMAP Projection of Embeddings", fontsize=18)
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.tick_params(axis='both', labelsize=14)

    def format_lab_label(label):
        """Turn e.g. 'source_3' into 'Laboratory 03'; other labels pass through."""
        text = str(label)
        parts = text.split('_')
        if len(parts) > 1 and parts[1].isdigit():
            return "Laboratory " + parts[1].zfill(2)
        return text

    # --- Legends, anchored outside the axes on the right ---
    if shape_by_JCP_id:
        handles_shape = [
            Line2D([0], [0], marker=marker_dict[group], color='w',
                   markerfacecolor='gray', markersize=10, markeredgecolor='k')
            for group in jcp_groups
        ]
        legend_shape = ax.legend(
            handles=handles_shape, labels=[str(group) for group in jcp_groups],
            title="Metadata_JCP2022", fontsize=12, title_fontsize=14,
            bbox_to_anchor=(1.05, 1), loc="upper left"
        )
        if color_by_source:
            # A second ax.legend() call replaces the first legend, so the shape
            # legend is re-attached as a plain artist to keep both visible.
            ax.add_artist(legend_shape)
    if color_by_source:
        handles_color = [
            mpatches.Patch(color=source_colors[source], label=format_lab_label(source))
            for source in source_groups
        ]
        # Placed lower when the shape legend already occupies the top slot.
        anchor_y = 0.5 if shape_by_JCP_id else 1
        ax.legend(
            handles=handles_color,
            title="Metadata_Source", fontsize=12, title_fontsize=14,
            bbox_to_anchor=(1.05, anchor_y), loc="upper left"
        )

    fig.subplots_adjust(right=0.75)
    plt.tight_layout()
    plt.show()


# Load and preprocess

In [7]:
# Build an EmbeddingManager at the "well" entity level from the metadata
# table, then load the embedding array under the name "Embeddings_Raw".
well_em = EmbeddingManager(metadata_path, entity="well")
well_em.load("Embeddings_Raw", embeddings_path)

In [8]:
# Restrict the data to the selected plates, then keep only the wells whose
# Metadata_JCP2022 value is in ``well_em.JCP_ID_controls``.
# NOTE(review): presumably JCP_ID_controls holds the JCP ids of the control
# compounds - confirm against phenoseeker.
selected_em = well_em.filter_and_instantiate(Metadata_Plate=selected_plates)
controls_em = selected_em.filter_and_instantiate(Metadata_JCP2022=well_em.JCP_ID_controls)

In [None]:
# Number of distinct plates present in the filtered control wells (shown as
# the cell output).
controls_em.df['Metadata_Plate'].nunique()

In [None]:
# Run compute_maps on the raw embeddings with the plate as label column.
# NOTE(review): presumably mean-average-precision scores, with
# random_maps=True adding a random baseline - confirm against phenoseeker.
controls_em.compute_maps('Metadata_Plate', ['Embeddings_Raw'], random_maps=True)

In [None]:
# Same compute_maps call as for plates, but labelled by 'Metadata_Source'.
# NOTE(review): presumably measures how strongly the raw embeddings cluster
# by source - confirm against phenoseeker.
controls_em.compute_maps('Metadata_Source', ['Embeddings_Raw'], random_maps=True)

In [None]:
# Keep only the wells listed in ``controls_em.JCP_ID_poscon`` and run
# compute_maps on them with the compound id ('Metadata_JCP2022') as label.
# NOTE(review): "poscon" presumably means positive controls - confirm.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
poscon_em.compute_maps('Metadata_JCP2022', ['Embeddings_Raw'], random_maps=True)

# Normalise

In [None]:
# Baseline view: UMAP of the untransformed "Embeddings_Raw" embeddings for
# the JCP_ID_poscon wells, before any normalisation is applied.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings_Raw']

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
# Derive "Embeddings_Raw_ZCA_N_C" from "Embeddings_Raw" with the "ZCA" method.
# NOTE(review): the two positional booleans presumably match the "_N" and "_C"
# suffixes of the output name - confirm against the phenoseeker signature.
controls_em.apply_spherizing_transform('Embeddings_Raw', 'Embeddings_Raw_ZCA_N_C', "ZCA", True, True)

In [None]:
# Re-filter controls_em (which now holds the transformed embeddings) down to
# the JCP_ID_poscon wells and plot the "Embeddings_Raw_ZCA_N_C" embeddings.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings_Raw_ZCA_N_C']

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
# Three chained steps, each reading the output name of the previous one:
#   1. robust Z-score of "Embeddings_Raw"      -> "Embeddings_rZMi_C"
#   2. "ZCA-cor" spherizing of that result     -> "Embeddings_rZMi_C__ZCA-cor_N_C"
#   3. inverse normal transform of that result -> "..._N_C__Int"
# NOTE(review): the meaning of the positional arguments (True, 'mean', 'iqrs',
# True, True) is not visible here - confirm against phenoseeker.
controls_em.apply_robust_Z_score('Embeddings_Raw', 'Embeddings_rZMi_C', True, 'mean', 'iqrs')
controls_em.apply_spherizing_transform('Embeddings_rZMi_C', 'Embeddings_rZMi_C__ZCA-cor_N_C', "ZCA-cor", True, True)
controls_em.apply_inverse_normal_transform('Embeddings_rZMi_C__ZCA-cor_N_C', 'Embeddings_rZMi_C__ZCA-cor_N_C__Int')

In [None]:
# Plot the last output of the robust-Z -> ZCA-cor -> inverse-normal chain
# for the JCP_ID_poscon wells.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings_rZMi_C__ZCA-cor_N_C__Int']

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
# "ZCA" spherizing of the raw embeddings (first boolean False this time),
# followed by an inverse normal transform -> "Embeddings__ZCA_C__Int".
# NOTE(review): the robust Z-score below writes "Embeddings_rZMi_C", but the
# next two lines read "Embeddings_Raw" and "Embeddings__ZCA_C"; its output is
# not consumed in this cell - confirm whether the call is still needed.
controls_em.apply_robust_Z_score('Embeddings_Raw', 'Embeddings_rZMi_C', True, 'mean', 'iqrs')
controls_em.apply_spherizing_transform('Embeddings_Raw', 'Embeddings__ZCA_C', "ZCA", False, True)
controls_em.apply_inverse_normal_transform('Embeddings__ZCA_C', 'Embeddings__ZCA_C__Int')

In [None]:
# Plot the "Embeddings__ZCA_C__Int" embeddings for the JCP_ID_poscon wells.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings__ZCA_C__Int']

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
# "ZCA" spherizing of the raw embeddings (both booleans True), followed by an
# inverse normal transform -> "Embeddings_Raw__ZCA_N_C__Int".
# NOTE(review): the robust Z-score below writes "Embeddings_rZMi_C", which the
# next two lines never read - confirm whether the call is still needed.
controls_em.apply_robust_Z_score('Embeddings_Raw', 'Embeddings_rZMi_C', True, 'mean', 'iqrs')
controls_em.apply_spherizing_transform('Embeddings_Raw', 'Embeddings_Raw__ZCA_N_C', "ZCA", True, True)
controls_em.apply_inverse_normal_transform('Embeddings_Raw__ZCA_N_C', 'Embeddings_Raw__ZCA_N_C__Int')

In [None]:
# Plot the inverse-normal-transformed ZCA embeddings for the JCP_ID_poscon
# wells. The key is "Embeddings_Raw__ZCA_N_C__Int", the final output of the
# preceding normalisation cell. Plotting the intermediate
# "Embeddings_Raw__ZCA_N_C" would leave the inverse normal transform unused
# and repeat the ZCA view already shown earlier in the notebook.
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings_Raw__ZCA_N_C__Int']

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
# Same three calls, with the same arguments, as the earlier
# "Embeddings__ZCA_C__Int" cell; running this cell recomputes those outputs.
# NOTE(review): as in that cell, the robust Z-score output
# "Embeddings_rZMi_C" is not read by the two lines that follow.
controls_em.apply_robust_Z_score('Embeddings_Raw', 'Embeddings_rZMi_C', True, 'mean', 'iqrs')
controls_em.apply_spherizing_transform('Embeddings_Raw', 'Embeddings__ZCA_C', "ZCA", False, True)
controls_em.apply_inverse_normal_transform('Embeddings__ZCA_C', 'Embeddings__ZCA_C__Int')

In [None]:
# Plot "Embeddings__ZCA_C__Int" for the JCP_ID_poscon wells (same view as the
# earlier cell that plots this key).
poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
df = poscon_em.df
embeddings = poscon_em.embeddings['Embeddings__ZCA_C__Int']

plot_umap(embeddings, df, shape_by_JCP_id=True)