# Imports

In [None]:
from phenoseeker import EmbeddingManager
from pathlib import Path
import pandas as pd
import numpy as np

In [None]:
# Load the three metadata tables that the plate-selection cells below compare.
# NOTE(review): the "final" path has no file extension - presumably a parquet
# directory rather than a single file; confirm.
df_test_meta = pd.read_parquet("/projects/cpjump1/jump/load_data/test_training.parquet")
df_meta = pd.read_parquet("/projects/cpjump1/jump/load_data/final")
df_small_meta = pd.read_parquet("/projects/cpjump1/jump/load_data/small_training.parquet")

In [None]:
# Metadata read from the "Metadata_Source=source_6" location.
# NOTE: `df` is rebound further down (to the LISI results, then to the "final"
# table), so the cell execution order matters for every cell that uses `df`.
df = pd.read_parquet("/projects/cpjump1/jump/load_data/load_data_with_metadata/Metadata_Source=source_6")

In [None]:
from pathlib import Path
from phenoseeker import EmbeddingManager


class EmbeddingsEvaluator:
    """
    A class that encapsulates the evaluation pipeline for embeddings.

    The pipeline includes:
      - Loading the embeddings using a specified entity and embeddings name.
      - Grouping the embeddings by a specified key (e.g., "well") and keeping selected
        columns.
      - Applying a spherizing transform; the result is stored under
        "<embeddings_name>_sph". The inverse normal transform is currently disabled.
      - Filtering the grouped embeddings to the rows whose labels column matches the
        manager's ``JCP_ID_poscon`` attribute.
      - Computing evaluation metrics (maps and lisi) on the "<embeddings_name>_sph"
        embeddings.
    """

    # Defaults are kept as tuples so that instances never share a mutable list.
    DEFAULT_COLS_TO_KEEP = (
        "Metadata_JCP2022",
        "Metadata_InChI",
        "Metadata_Well",
    )
    DEFAULT_N_NEIGHBORS_LIST = (15,)

    def __init__(
        self,
        metadata_path: Path,
        *,
        entity: str = "image",
        embeddings_name: str = "chad_dino",
        group_by: str = "well",
        cols_to_keep=None,
        labels_column: str = "Metadata_JCP2022",
        n_neighbors_list=None,
    ):
        """
        Initializes the evaluator.

        All arguments except ``metadata_path`` are keyword-only and default to the
        values that were previously hard-coded.

        Args:
            metadata_path (Path): Path to the metadata.
            entity (str): The entity type (e.g., "image").
            embeddings_name (str): The name of the embeddings to load.
            group_by (str): The column name to group embeddings (e.g., "well").
            cols_to_keep (list): List of metadata columns to keep. Defaults to
                ``DEFAULT_COLS_TO_KEEP``.
            labels_column (str): The column name used for evaluation labels.
            n_neighbors_list (list): A list of neighbor counts for computing lisi.
                Defaults to ``DEFAULT_N_NEIGHBORS_LIST``.
        """
        self.metadata_path = metadata_path
        self.entity = entity
        self.embeddings_name = embeddings_name
        self.group_by = group_by
        # Copy into fresh lists so callers' sequences are never mutated or shared.
        self.cols_to_keep = list(
            self.DEFAULT_COLS_TO_KEEP if cols_to_keep is None else cols_to_keep
        )
        self.labels_column = labels_column
        self.n_neighbors_list = list(
            self.DEFAULT_N_NEIGHBORS_LIST
            if n_neighbors_list is None
            else n_neighbors_list
        )

        # Populated step by step by the pipeline methods below.
        self.em = None
        self.em_grouped = None
        self.em_filtered = None

    @property
    def sph_embeddings_name(self) -> str:
        """Name under which the spherized embeddings are stored and evaluated."""
        return f"{self.embeddings_name}_sph"

    def load_embeddings(self, cls_tokens):
        """
        Loads the embeddings using the provided CLS tokens.

        Args:
            cls_tokens: The CLS tokens extracted from the model. They are handed
                unchanged to ``EmbeddingManager.load``.

        Returns:
            The newly created ``EmbeddingManager`` (also stored in ``self.em``).
        """
        self.em = EmbeddingManager(self.metadata_path, entity=self.entity)
        self.em.load(self.embeddings_name, cls_tokens)
        return self.em

    def group_embeddings(self):
        """
        Groups the embeddings by the specified column and keeps selected metadata cols.

        Returns:
            The grouped manager (also stored in ``self.em_grouped``).

        Raises:
            RuntimeError: If ``load_embeddings()`` has not been called yet.
        """
        if self.em is None:
            raise RuntimeError("Embeddings not loaded. Call load_embeddings() first.")
        self.em_grouped = self.em.grouped_embeddings(
            group_by=self.group_by,
            cols_to_keep=self.cols_to_keep,
        )
        return self.em_grouped

    def apply_transforms(self):
        """
        Applies the spherizing transform to the grouped embeddings.

        The transformed embeddings are stored under ``sph_embeddings_name``.

        Returns:
            The grouped manager (``self.em_grouped``).

        Raises:
            RuntimeError: If ``group_embeddings()`` has not been called yet.
        """
        if self.em_grouped is None:
            raise RuntimeError(
                "Grouped embeddings not available. Call group_embeddings() first."
            )

        self.em_grouped.apply_spherizing_transform(
            embeddings_name=self.embeddings_name,
            new_embeddings_name=self.sph_embeddings_name,
            norm_embeddings=False,
        )
        # NOTE: the inverse normal transform (apply_inverse_normal_transform on the
        # "_sph" embeddings, producing "_sph_int") is disabled here, so the metrics
        # are computed on the "_sph" embeddings.
        return self.em_grouped

    def filter_embeddings(self):
        """
        Filters the grouped embeddings using a predefined condition.

        Keeps the rows whose ``labels_column`` value matches the grouped manager's
        ``JCP_ID_poscon`` attribute.

        Returns:
            The filtered manager (also stored in ``self.em_filtered``).

        Raises:
            RuntimeError: If ``group_embeddings()`` has not been called yet.
        """
        if self.em_grouped is None:
            raise RuntimeError(
                "Grouped embeddings not available. Call group_embeddings() first."
            )

        # Use the attribute from the grouped embeddings as a filter condition
        self.em_filtered = self.em_grouped.filter_and_instantiate(
            **{self.labels_column: self.em_grouped.JCP_ID_poscon}
        )
        return self.em_filtered

    def compute_metrics(self):
        """
        Computes the evaluation metrics (maps and lisi) on the filtered embeddings.

        Returns:
            tuple: A tuple (maps, lisi) where ``maps`` is the last entry of the
            "mAP (<embeddings_name>_sph)" column of the maps result and ``lisi`` is
            the first entry of the "<embeddings_name>_sph" column of the lisi result.
            Only the first lisi row is returned, even when ``n_neighbors_list`` has
            several entries.

        Raises:
            RuntimeError: If ``filter_embeddings()`` has not been called yet.
        """
        if self.em_filtered is None:
            raise RuntimeError(
                "Filtered embeddings not available. Call filter_embeddings() first."
            )

        sph_name = self.sph_embeddings_name
        maps = self.em_filtered.compute_maps(
            labels_column=self.labels_column,
            embeddings_names=sph_name,
        )
        lisi = self.em_filtered.compute_lisi(
            labels_column=self.labels_column,
            embeddings_names=[sph_name],
            n_neighbors_list=self.n_neighbors_list,
        )
        return maps[f"mAP ({sph_name})"].iloc[-1], lisi[sph_name].iloc[0]

    def run_pipeline(self, cls_tokens):
        """
        Runs the complete evaluation pipeline:
          1. Load embeddings.
          2. Group embeddings.
          3. Apply transformations.
          4. Filter embeddings.
          5. Compute evaluation metrics.

        Args:
            cls_tokens: The CLS tokens extracted from the model.

        Returns:
            tuple: The computed metrics (maps, lisi).
        """
        self.load_embeddings(cls_tokens)
        self.group_embeddings()
        self.apply_transforms()
        self.filter_embeddings()
        return self.compute_metrics()




In [None]:

# Smoke-test the evaluator end to end with random (fake) CLS tokens.
# NOTE(review): test_5_plates.parquet is written by a cell further down in this
# notebook, so that cell must have been run at least once before this one.
evaluator = EmbeddingsEvaluator(
    metadata_path=Path("/projects/cpjump1/jump/load_data/test_5_plates.parquet"),
)
# Fake CLS tokens: 6397 samples, each of size 384.
# NOTE(review): 6397 presumably has to match the number of rows in the metadata
# file - confirm, and consider deriving it from the file instead.
cls_tokens = np.random.rand(6397, 384).astype(np.float32)
maps, lisi = evaluator.run_pipeline(cls_tokens)
print("MAPs:", maps)
print("LISI:", lisi)


In [None]:
# Display the mAP value returned by run_pipeline.
maps

In [None]:
# Display the LISI value returned by run_pipeline.
lisi

In [None]:
import numpy as np

# The fake CLS tokens (6397 samples, each of size 384) are generated with
# np.random.rand in the evaluator cell above; this cell only imports numpy.


In [None]:
# Inner-join df_test with df on well/site/batch/plate, keep only the identifying
# columns, and write the result to test_5_plates.parquet (without the index).
# NOTE(review): df_test is defined in a later cell, and `df` is rebound later in
# the notebook - run order matters here.
(
    df_test.merge(
        df,
        on=['Metadata_Well', "Metadata_Site", "Metadata_Batch", "Metadata_Plate"],
    )[
        [
            'Metadata_Source',
            'Metadata_Plate',
            'Metadata_Well',
            'Metadata_Site',
            'Metadata_JCP2022',
            'Metadata_InChI',
        ]
    ].to_parquet(
        "/projects/cpjump1/jump/load_data/test_5_plates.parquet",
        index=False,
    )
)

In [None]:
# Display the selected test rows (df_test is assigned in a later cell).
df_test

In [None]:
# Number of distinct plates in the small training metadata.
len(set(df_small_meta['Metadata_Plate']))

In [None]:
# Candidate plates: source_6 plates from df_meta that appear in neither
# df_test_meta nor df_small_meta.
plates = set(
    df_meta.loc[df_meta['Metadata_Source'] == 'source_6', 'Metadata_Plate']
).difference(
    df_test_meta['Metadata_Plate'],
    df_small_meta['Metadata_Plate'],
)

In [None]:
import random

def select_random_items(plates, n, seed=None):
    """
    Selects n random items from the given set.

    Args:
        plates (set): A set of items.
        n (int): The number of random items to select.
        seed (int, optional): When given, the items are sorted and sampled with a
            dedicated ``random.Random(seed)``, so the same selection is returned
            on every run. The items must then be mutually orderable. When None
            (the default), the module-level generator is used and the selection
            differs between runs.

    Returns:
        list: A list containing n randomly selected items (sampled without
        replacement).

    Raises:
        ValueError: If n is negative or larger than the number of items.
    """
    if seed is None:
        # Convert the set to a list, then sample n items randomly.
        return random.sample(list(plates), n)
    # Set iteration order is not stable across interpreter runs (string hashing is
    # randomized), so sort first to make the seeded draw reproducible.
    return random.Random(seed).sample(sorted(plates), n)

# Draw 3 held-out plates from the candidates.
# NOTE: the draw is not seeded here, so a different set is selected on every run.
n = 3
random_items = select_random_items(plates, n)
print(random_items)

In [None]:
# Display the plates selected above.
random_items

In [None]:
# Rows of the full metadata table that belong to the randomly selected plates.
df_test = df_meta[df_meta['Metadata_Plate'].isin(random_items)]

In [None]:
# Persist the selection; re-running the sampling cell and then this one
# overwrites the file with a different set of plates.
df_test.to_parquet("/projects/cpjump1/jump/load_data/eval_loader.parquet")


In [None]:
# NOTE(review): leftover scratch cell - indexing with an empty column name raises
# KeyError unless a column literally named '' exists. Remove or complete it.
df_test['']

# Load chad image embeddings

In [None]:
# Folder holding the temporary embedding exports used by the next cells.
# NOTE: `base_path` is reassigned inside the "other wells embeddings" loop further
# down, so re-run this cell before re-running the path definitions below.
base_path = Path("/projects/imagesets4/temp_embeds/")

In [None]:
! ls /projects/imagesets4/temp_embeds/

In [None]:

# One (embeddings .npy, metadata .parquet) pair per chad DINOv2-S export variant.

# Plain CLS token.
chad_cls_feats = base_path / "ctrls_images_chad_dinov2s_cls_embeds.npy"
chad_cls_metadata = base_path / "ctrls_images_chad_dinov2s_cls_dataframe.parquet"

# "sm02" variant.
chad_cls_sm02_feats = base_path / "ctrls_images_chad_dinov2s_cls_sm02_embeds.npy"
chad_cls_sm02_metadata = (
    base_path / "ctrls_images_chad_dinov2s_cls_sm02_dataframe.parquet"
)

# "sm12x02_w_regs" variant.
chad_cls_sm12x02_w_regs_feats = (
    base_path / "ctrls_images_chad_dinov2s_cls_sm12x02_w_regs_embeds.npy"
)
chad_cls_sm12x02_w_regs_metadata = (
    base_path / "ctrls_images_chad_dinov2s_cls_sm12x02_w_regs_dataframe.parquet"
)

# "w_regs" variant.
chad_cls_w_regs_feats = base_path / "ctrls_images_chad_dinov2s_cls_w_regs_embeds.npy"
chad_cls_w_regs_metadata = (
    base_path / "ctrls_images_chad_dinov2s_cls_w_regs_dataframe.parquet"
)

# "sm02_w_regs" variant.
chad_cls_sm02_w_regs_feats = (
    base_path / "ctrls_images_chad_dinov2s_cls_sm02_w_regs_embeds.npy"
)
chad_cls_sm02_w_regs_metadata = (
    base_path / "ctrls_images_chad_dinov2s_cls_sm02_w_regs_dataframe.parquet"
)

In [None]:
# Image-level manager, initialised from the "sm02_w_regs" metadata file.
chad_em_img = EmbeddingManager(chad_cls_sm02_w_regs_metadata, entity="image")

In [None]:
# Register every export variant on the image-level manager under its own name.
chad_em_img.load("chad_cls", chad_cls_feats, chad_cls_metadata)
chad_em_img.load("chad_cls_sm02", chad_cls_sm02_feats, chad_cls_sm02_metadata)
# NOTE(review): this is the only load that does not pass its metadata file, even
# though chad_cls_w_regs_metadata is defined above - confirm this is intentional.
chad_em_img.load("chad_cls_w_regs", chad_cls_w_regs_feats)
chad_em_img.load("chad_cls_sm02_w_regs", chad_cls_sm02_w_regs_feats, chad_cls_sm02_w_regs_metadata)
chad_em_img.load("chad_cls_sm12x02_w_regs", chad_cls_sm12x02_w_regs_feats, chad_cls_sm12x02_w_regs_metadata)


In [None]:
# Group the image-level embeddings by well, keeping the listed metadata columns.
chad_em_well = chad_em_img.grouped_embeddings(group_by="well", cols_to_keep=['Metadata_Batch','Metadata_JCP2022', 'Metadata_InChI', "Metadata_Well"])

In [None]:

# For every embedding currently on the well-level manager, derive a spherized
# version ("<name>_sph") and, from it, an inverse-normal-transformed version
# ("<name>_sph_int"). The names are snapshotted before looping - presumably
# because the transforms add new entries to `embeddings` while we iterate.
for model_name in tuple(chad_em_well.embeddings):
    chad_em_well.apply_spherizing_transform(
        embeddings_name=model_name,
        new_embeddings_name=f"{model_name}_sph",
        norm_embeddings=False,
    )
    chad_em_well.apply_inverse_normal_transform(
        embeddings_name=f"{model_name}_sph",
        new_embeddings_name=f"{model_name}_sph_int",
    )

# Persist the well-level manager, including the derived embeddings.
chad_em_well.save_to_folder(
    Path('/projects/synsight/data/jump_embeddings/wells_embeddings/chad/')
)

# Add well-level embeddings from other models

In [None]:

# Load the precomputed well-level embeddings of each comparison model into the
# same manager and derive the "_sph" and "_sph_int" versions for each.
for model_name in ['dinov2_s', 'openphenom', 'resnet50', 'chada']:
    # Use a loop-local name instead of `base_path`, so the global `base_path`
    # (which the chad embedding path definitions rely on) is not overwritten.
    model_dir = Path(f'/projects/synsight/data/jump_embeddings/wells_embeddings/{model_name}')

    model_metadata_path = model_dir / f'metadata_{model_name}.parquet'
    model_embeddings_path = model_dir / f'embeddings_{model_name}.npy'
    chad_em_well.load(model_name, model_embeddings_path, model_metadata_path)

    chad_em_well.apply_spherizing_transform(
        embeddings_name=model_name,
        new_embeddings_name=f"{model_name}_sph",
        norm_embeddings=False,
    )
    chad_em_well.apply_inverse_normal_transform(
        embeddings_name=f"{model_name}_sph",
        new_embeddings_name=f"{model_name}_sph_int",
    )


In [None]:
# Keep only the wells whose Metadata_JCP2022 matches the manager's JCP_ID_poscon
# attribute (presumably the positive-control compounds - confirm).
chad_em_well_poscon = chad_em_well.filter_and_instantiate(Metadata_JCP2022=chad_em_well.JCP_ID_poscon)

In [None]:
# Evaluate only the embeddings whose name contains "sph_int".
embeddings_to_test = [
    name for name in chad_em_well_poscon.embeddings if "sph_int" in name
]

In [None]:
# mAP with Metadata_Source as the label, computed on all wells (chad_em_well),
# not on the poscon subset.
# NOTE(review): Metadata_Source is not in the cols_to_keep used when grouping by
# well - confirm the column is still available on the grouped manager.
maps_source = chad_em_well.compute_maps(labels_column="Metadata_Source", embeddings_names=embeddings_to_test, random_maps=False, plot=True)

In [None]:
# mAP with Metadata_JCP2022 as the label, computed on the poscon subset.
maps_jcp = chad_em_well_poscon.compute_maps(labels_column="Metadata_JCP2022", embeddings_names=embeddings_to_test, random_maps=False, plot=True)

In [None]:
# LISI with Metadata_JCP2022 as the label on the poscon subset, for four
# neighbour counts (5, 10, 20, 40).
lisi_jcp_2 = chad_em_well_poscon.compute_lisi(labels_column="Metadata_JCP2022", embeddings_names=embeddings_to_test, plot=True, n_neighbors_list=[5, 10, 20, 40])

In [None]:
# NOTE: rebinds `df` (previously the source_6 metadata) to the LISI results so
# the plotting cells below can use it.
df = lisi_jcp_2

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Plot every column of `df` (bound to the LISI results in the previous cell)
# except the reference columns whose name starts with "Ideal mixing".
model_columns = [col for col in df.columns if not col.startswith('Ideal mixing')]

plt.figure(figsize=(10, 6))
for col in model_columns:
    # One line per column, drawn against the DataFrame index.
    plt.plot(df.index, df[col], marker='o', label=col)

# NOTE(review): the axis labels and title are generic placeholders - consider
# naming what the index and the values actually represent.
plt.xlabel("Index")
plt.ylabel("Values")
plt.title("Model Values")

# Place the legend outside the plot on the right side
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()  # Adjust layout so nothing is cut off
plt.show()



In [None]:
# Display the column names that were plotted above.
model_columns

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Same plot as above, restricted to a hand-picked subset of the "_sph_int"
# columns of `df` (the list is hard-coded, not derived from df.columns).
model_columns = ['chad_cls_sph_int',
 'chad_cls_sm02_sph_int',
 'chad_cls_w_regs_sph_int',
 'chad_cls_sm02_w_regs_sph_int',
 'dinov2_s_sph_int',
 'chada_sph_int']

plt.figure(figsize=(10, 6))
for col in model_columns:
    # One line per selected column, drawn against the DataFrame index.
    plt.plot(df.index, df[col], marker='o', label=col)

# NOTE(review): the axis labels and title are generic placeholders - consider
# naming what the index and the values actually represent.
plt.xlabel("Index")
plt.ylabel("Values")
plt.title("Model Values")

# Place the legend outside the plot on the right side
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()  # Adjust layout so nothing is cut off
plt.show()



In [None]:
# NOTE: rebinds `df` once more, this time to the "final" metadata table (the same
# path that df_meta was read from at the top of the notebook).
df = pd.read_parquet('/projects/cpjump1/jump/load_data/final')

In [None]:
# Inspect the available columns.
df.columns

In [None]:
# Number of rows in the table.
len(df)

In [None]:
# NOTE(review): incomplete call - DataFrame.merge requires a `right` argument, so
# this raises TypeError as written. Supply the frame and join keys, or remove it.
df.merge()