# Imports 

In [1]:
from pathlib import Path
from phenoseeker import EmbeddingManager
import matplotlib.pyplot as plt
import umap
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import pandas as pd

In [2]:
embeddings_path = Path("/home/maxime/data/jump_embeddings/embeddings_chada.npy")
metadata_path = Path("/home/maxime/data/jump_embeddings/metadata_chada.parquet")

In [3]:
selected_plates = [  #'UL000585',
    #'UL001773',
    #'GR00004405',
    #'UL000097',
    #'GR00003344',
    #'GR00003285',
    "1086292853",
    "EC000134",
    "B40703bW",
    "EC000065",
    "BR5873d3W",
    "J12424d",
    "1086291931",
    "EC000071",
    "110000296323",
    "AEOJUM504",
    "110000295571",
    "110000297103",
    #'1086293027',
    "A1170530",
    "Dest210726-161624",
    "Dest210809-135330",
    "A1170468",
    "A1170419",
    "APTJUM208",
    "Dest210823-180708",
    "AEOJUM902",
]

In [4]:
selected_plates = pd.read_json(
    "/home/maxime/synrepos/phenoseeker/scripts/balanced_plates_51_lab.json"
)["Metadata_Plate"].to_list()

FileNotFoundError: File /home/maxime/synrepos/phenoseeker/scripts/balanced_plates_51_lab.json does not exist

In [None]:
def plot_umap(embeddings, df, color_by_source=True, shape_by_JCP_id=True):
    """
    Plot a UMAP projection of the embeddings with inverted mapping:
      - Marker shapes are determined by 'Metadata_Source'
      - Colors are determined by 'Metadata_JCP2022'
    In the legend for the JCP groups, the names are replaced by "Positive Control 1", "Positive Control 2", etc.
    The legend titles are set as:
      - "Source Laboratories" for the marker shapes (Metadata_Source)
      - "Positive Control Compounds" for the colors (Metadata_JCP2022)
    The output image is generated with high resolution (dpi=300) suitable for a Nature publication.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Array of shape (n_samples, n_features) containing the embeddings.
    df : pandas.DataFrame
        DataFrame containing at least the following columns:
          - For marker shapes: 'Metadata_Source'
          - For colors: 'Metadata_JCP2022'
    color_by_source : bool, optional (default True)
        If True, marker shapes are assigned according to 'Metadata_Source' (8 distinct markers).
    shape_by_JCP_id : bool, optional (default True)
        If True, colors are assigned according to 'Metadata_JCP2022' (using a colormap).
    """
    import matplotlib.pyplot as plt
    import umap
    import matplotlib.patches as mpatches
    from matplotlib.lines import Line2D

    # Helper function to format source labels as "Laboratory XX"
    def format_lab_label(label):
        parts = label.split("_")
        if len(parts) > 1 and parts[1].isdigit():
            return "Laboratory " + parts[1].zfill(2)
        else:
            return label

    # --- Dimension reduction with UMAP ---
    reducer = umap.UMAP(n_components=2, random_state=42)
    embedding_2d = reducer.fit_transform(embeddings)
    df["UMAP1"] = embedding_2d[:, 0]
    df["UMAP2"] = embedding_2d[:, 1]

    # Create figure and axis with high dpi for high-quality output
    fig, ax = plt.subplots(figsize=(12, 10), dpi=1000)

    # --- Mapping inversion ---
    # Marker shapes now from 'Metadata_Source'
    if color_by_source:
        source_groups = sorted(df["Metadata_Source"].unique())
        markers = ["o", "s", "^", "v", "D", "P", "X", "*"]
        marker_dict = {
            group: markers[i % len(markers)] for i, group in enumerate(source_groups)
        }
    else:
        marker_dict = None

    # Colors now from 'Metadata_JCP2022'
    if shape_by_JCP_id:
        jcp_groups = sorted(df["Metadata_JCP2022"].unique())
        cmap_jcp = plt.get_cmap("tab10")
        color_dict = {group: cmap_jcp(i) for i, group in enumerate(jcp_groups)}
    else:
        color_dict = None

    # --- Plotting points ---
    if shape_by_JCP_id and color_by_source:
        # For each combination of Metadata_Source and Metadata_JCP2022, plot the points
        for source in source_groups:
            for jcp in jcp_groups:
                subset = df[
                    (df["Metadata_Source"] == source) & (df["Metadata_JCP2022"] == jcp)
                ]
                if not subset.empty:
                    ax.scatter(
                        subset["UMAP1"],
                        subset["UMAP2"],
                        marker=marker_dict[source],
                        color=color_dict[jcp],
                        edgecolor="k",
                        s=100,
                        alpha=0.8,
                    )
    elif shape_by_JCP_id and not color_by_source:
        # Only color mapping from JCP, fixed marker shape
        for jcp in jcp_groups:
            subset = df[df["Metadata_JCP2022"] == jcp]
            ax.scatter(
                subset["UMAP1"],
                subset["UMAP2"],
                marker="o",
                color=color_dict[jcp],
                edgecolor="k",
                s=100,
                alpha=0.8,
                label=f"Positive Control {jcp}",
            )
    elif not shape_by_JCP_id and color_by_source:
        # Only marker shapes from Source, fixed color
        for source in source_groups:
            subset = df[df["Metadata_Source"] == source]
            ax.scatter(
                subset["UMAP1"],
                subset["UMAP2"],
                marker=marker_dict[source],
                color="blue",
                edgecolor="k",
                s=100,
                alpha=0.8,
                label=format_lab_label(source),
            )
    else:
        ax.scatter(
            df["UMAP1"],
            df["UMAP2"],
            marker="o",
            color="blue",
            edgecolor="k",
            s=100,
            alpha=0.8,
        )

    # --- Customizing the plot ---
    ax.set_xlabel("UMAP1", fontsize=16)
    ax.set_ylabel("UMAP2", fontsize=16)
    ax.set_title("UMAP Projection of Embeddings", fontsize=18)
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.tick_params(axis="both", labelsize=14)

    # --- Legends ---
    if shape_by_JCP_id and color_by_source:
        # Legend for marker shapes (Metadata_Source) with formatted labels ("Laboratory XX")
        handles_shape = [
            Line2D(
                [0],
                [0],
                marker=marker_dict[source],
                color="w",
                markerfacecolor="gray",
                markersize=10,
                markeredgecolor="k",
            )
            for source in source_groups
        ]
        legend_shape = ax.legend(
            handles=handles_shape,
            labels=[format_lab_label(source) for source in source_groups],
            fontsize=12,
            title="",
            title_fontsize=14,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )
        ax.add_artist(legend_shape)
        # Legend for colors (Metadata_JCP2022) with replaced names
        handles_color = [
            mpatches.Patch(color=color_dict[jcp], label=f"Positive Control {i+1}")
            for i, jcp in enumerate(jcp_groups)
        ]
        legend_color = ax.legend(
            handles=handles_color,
            fontsize=12,
            title="",
            title_fontsize=14,
            bbox_to_anchor=(1.05, 0.5),
            loc="upper left",
        )
    elif shape_by_JCP_id and not color_by_source:
        handles_color = [
            mpatches.Patch(color=color_dict[jcp], label=f"Positive Control {i+1}")
            for i, jcp in enumerate(jcp_groups)
        ]
        legend_color = ax.legend(
            handles=handles_color,
            title="",
            fontsize=12,
            title_fontsize=14,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )
    elif not shape_by_JCP_id and color_by_source:
        handles_shape = [
            Line2D(
                [0],
                [0],
                marker=marker_dict[source],
                color="w",
                markerfacecolor="gray",
                markersize=10,
                markeredgecolor="k",
            )
            for source in source_groups
        ]
        legend_shape = ax.legend(
            handles=handles_shape,
            labels=[format_lab_label(source) for source in source_groups],
            fontsize=12,
            title="Source Laboratories",
            title_fontsize=14,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )

    fig.subplots_adjust(right=0.75)
    plt.tight_layout()
    plt.show()

# Load and preprocess

In [None]:
well_em = EmbeddingManager(metadata_path, entity="well")
well_em.load("Embeddings_Raw", embeddings_path)

In [None]:
selected_plates = (
    well_em.df.drop_duplicates(["Metadata_Source", "Metadata_Plate"])
    .groupby("Metadata_Source")
    .sample(3)["Metadata_Plate"]
    .tolist()
)

In [None]:
selected_em = well_em.filter_and_instantiate(Metadata_Plate=selected_plates)
controls_em = selected_em.filter_and_instantiate(
    Metadata_JCP2022=well_em.JCP_ID_controls
)

In [None]:
# Calculer le nombre de lignes par plaque
plate_counts = controls_em.df["Metadata_Plate"].value_counts()

# Identifier les plaques avec moins de 100 lignes
plates_to_keep = plate_counts[plate_counts < 100].index

# Filtrer la DataFrame pour conserver uniquement ces plaques
controls_em = controls_em.filter_and_instantiate(Metadata_Plate=list(plates_to_keep))

In [None]:
controls_em.df["Metadata_Plate"].nunique()

In [None]:
# controls_em.compute_maps('Metadata_Plate', ['Embeddings_Raw'], random_maps=True)

In [None]:
# controls_em.compute_maps('Metadata_Source', ['Embeddings_Raw'], random_maps=True)

In [None]:
# poscon_em  = controls_em.filter_and_instantiate(Metadata_JCP2022=controls_em.JCP_ID_poscon)
# poscon_em.compute_maps('Metadata_JCP2022', ['Embeddings_Raw'], random_maps=True)

# Normalise

In [None]:
poscon_em = controls_em.filter_and_instantiate(
    Metadata_JCP2022=controls_em.JCP_ID_poscon
)
df = poscon_em.df
embeddings = poscon_em.embeddings["Embeddings_Raw"]

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
controls_em.apply_spherizing_transform(
    "Embeddings_Raw", "Embeddings_Raw_ZCA_N_C", "ZCA", True, True
)
poscon_em = controls_em.filter_and_instantiate(
    Metadata_JCP2022=controls_em.JCP_ID_poscon
)
df = poscon_em.df
embeddings = poscon_em.embeddings["Embeddings_Raw_ZCA_N_C"]

plot_umap(embeddings, df, shape_by_JCP_id=True)

In [None]:
controls_em.apply_spherizing_transform(
    "Embeddings_Raw", "Embeddings__ZCA_C", "ZCA", False, True
)
controls_em.apply_inverse_normal_transform(
    "Embeddings__ZCA_C", "Embeddings__ZCA_C__Int"
)

poscon_em = controls_em.filter_and_instantiate(
    Metadata_JCP2022=controls_em.JCP_ID_poscon
)
df = poscon_em.df
embeddings = poscon_em.embeddings["Embeddings__ZCA_C__Int"]

plot_umap(embeddings, df, shape_by_JCP_id=True)