# Imports

In [None]:
from pathlib import Path
from phenoseeker import EmbeddingManager
import pandas as pd


# Prepare chad well- and compound-level profiles

In [None]:
# Directory containing the temporary embedding dumps read in the cells below.
base_path = Path("/projects/imagesets4/temp_embeds/")

In [None]:
# List the available embedding dumps.
# This reuses `base_path` instead of repeating the hard-coded directory.
# It uses plain Python instead of the IPython `!ls` shell escape, so the cell
# also runs outside IPython/Jupyter.
for entry in sorted(base_path.iterdir()):
    print(entry.name)

In [None]:

# chad CLS feature matrix (.npy) and the matching per-image metadata (.parquet).
# NOTE(review): the metadata file name contains "images_images" while the
# features file does not; confirm this is the real file name and not a typo.
chad_cls_feats = base_path / "jump_all_images_chad_dinov2s_cls_sm02_w_regs_embeds.npy"
chad_cls_metadata = (
    base_path
    / "jump_all_images_images_chad_dinov2s_cls_sm02_w_regs_dataframe.parquet"
)


In [None]:
# Build an image-level EmbeddingManager from the per-image metadata parquet.
chad_em_img = EmbeddingManager(chad_cls_metadata, entity="image")

In [None]:
# Attach the feature matrix under the embedding name "chad_cls".
# Presumably its rows are aligned with the metadata rows; TODO confirm.
chad_em_img.load("chad_cls", chad_cls_feats)


In [None]:
# Read the raw per-image metadata separately, for inspection.
df_meta = pd.read_parquet(chad_cls_metadata)


In [None]:
# Inspect the raw metadata table.
df_meta

In [None]:
# Inspect the manager's own metadata DataFrame (for comparison with df_meta).
chad_em_img.df

In [None]:
# Full metadata table used to enrich the image rows. Downstream cells read
# columns such as Metadata_PlateType, Metadata_JCP2022 and Metadata_InChIKey
# from it.
# `df_meta` is not re-read here: it was already loaded from the same parquet
# file earlier and has not been modified since.
df_all_meta = pd.read_csv('/projects/cpjump1/jump/metadata/complete_metadata.csv')


In [None]:
# Enrich each image row with the matching plate/well metadata.
# The manager presumably keeps its embedding matrix aligned with `df` row by
# row (TODO confirm), so this merge must not add rows.
# validate="many_to_one" raises pandas.errors.MergeError if `df_all_meta` has
# duplicate (source, batch, plate, well) keys. Without it, a left merge would
# silently duplicate image rows.
chad_em_img.df = chad_em_img.df.merge(
    df_all_meta,
    on=['Metadata_Source', 'Metadata_Batch', 'Metadata_Plate', 'Metadata_Well'],
    how='left',
    validate='many_to_one',
)

In [None]:
# Collect the IDs of plates whose Metadata_PlateType is "COMPOUND".
is_compound_plate = df_all_meta['Metadata_PlateType'] == "COMPOUND"
df_comp = df_all_meta.loc[is_compound_plate]
plate_ids = df_comp['Metadata_Plate'].unique()
plates = plate_ids.tolist()

In [None]:
# Build a manager restricted to images from compound plates.
chad_em_img_comp = chad_em_img.filter_and_instantiate(Metadata_Plate=plates)


In [None]:
# Group image embeddings by well, carrying the batch and compound-identifier
# columns along to the well-level table.
well_cols_to_keep = [
    'Metadata_Batch',
    'Metadata_JCP2022',
    'Metadata_Well',
    'Metadata_InChIKey',
    'Metadata_InChI',
]
chad_em_well = chad_em_img_comp.grouped_embeddings(
    group_by='well',
    cols_to_keep=well_cols_to_keep,
)

In [None]:
# Persist the well-level embeddings. This runs before the plate exclusion and
# the spherize / inverse-normal transforms applied in the cells below.
wells_embeddings_path = Path('/projects/synsight/data/jump_embeddings/wells_embeddings/chad/')
chad_em_well.save_to_folder(wells_embeddings_path)

In [None]:
# Plates to drop before the transforms.
# NOTE(review): the name `plates_with_ctrl` suggests these three plates lack
# control wells; confirm.
plates_to_exclude = {
    "Dest210823-174240",
    "Dest210628-162003",
    "Dest210823-174422",
}
# This is an order-preserving filter. Unlike list.remove(), it does not raise
# ValueError when an excluded plate is absent. That happens when this cell is
# re-run after `chad_em_well` has been rebound to the already-filtered manager.
plates_with_ctrl = [
    plate
    for plate in chad_em_well.df['Metadata_Plate'].unique()
    if plate not in plates_to_exclude
]

In [None]:
# Keep only the retained plates. Note that this rebinds `chad_em_well` to the
# filtered manager, so re-running the previous cell computes the plate list
# from the already-filtered data.
chad_em_well = chad_em_well.filter_and_instantiate(Metadata_Plate=plates_with_ctrl)

In [None]:

# For every base embedding, apply the spherizing transform (norm_embeddings=False)
# and then the inverse normal transform. Results are stored as "<name>_sph" and
# "<name>_sph_int".
# The names are snapshotted before looping because the transforms add new keys
# to `embeddings`. Names derived by an earlier run of this cell are skipped, so
# re-running it does not create "<name>_sph_sph", "<name>_sph_int_sph", ...
_derived_suffixes = ("_sph", "_sph_int")
base_embedding_names = [
    name for name in chad_em_well.embeddings if not name.endswith(_derived_suffixes)
]
for model_name in base_embedding_names:
    chad_em_well.apply_spherizing_transform(
        embeddings_name=model_name,
        new_embeddings_name=f"{model_name}_sph",
        norm_embeddings=False,
    )
    chad_em_well.apply_inverse_normal_transform(
        embeddings_name=f"{model_name}_sph",
        new_embeddings_name=f"{model_name}_sph_int",
    )



In [None]:
# Group the well-level embeddings by compound, keeping the JCP2022 compound ID.
chad_em_comp = chad_em_well.grouped_embeddings(group_by='compound', cols_to_keep=['Metadata_JCP2022'])

In [None]:
# Persist the compound-level embeddings.
# NOTE(review): embeddings_name="chad_cls" presumably writes only the raw
# "chad_cls" embedding. The "_sph"/"_sph_int" variants computed on the wells
# would then not be saved here; confirm this is intended.
compounds_embeddings_path = Path('/projects/synsight/data/jump_embeddings/compounds_embeddings/chad')
chad_em_comp.save_to_folder(compounds_embeddings_path, embeddings_name="chad_cls")

# QC on positive controls (random subset of plates)

In [None]:
# Draw a small random subset of plates for quick QC.
# Sample from the *distinct* plates. Sampling the per-well column directly
# draws wells, so plates are weighted by well count. It can also pick the same
# plate twice, leaving fewer distinct plates than requested.
# Set QC_SEED to an int to make the subset reproducible.
N_QC_PLATES = 5
QC_SEED = None
unique_plates = chad_em_well.df['Metadata_Plate'].drop_duplicates()
random_plates = unique_plates.sample(
    min(N_QC_PLATES, len(unique_plates)), random_state=QC_SEED
).to_list()
small_chad_em_well = chad_em_well.filter_and_instantiate(Metadata_Plate=random_plates)

In [None]:

# Load each model's saved well-level embeddings onto the QC subset. Then apply
# the same spherize -> inverse-normal pipeline used on the full well set.
# NOTE(review): here the transforms run on the QC plate subset only; confirm
# that this is intended for the comparison.
# The per-model directory uses its own name (`model_dir`). This loop therefore
# no longer overwrites the notebook-level `base_path`, and the path variables
# no longer carry a misleading "_dino" suffix.
for model_name in ['chad', 'dinov2_s', 'openphenom', 'resnet50', 'chada']:
    model_dir = Path(f'/projects/synsight/data/jump_embeddings/wells_embeddings/{model_name}')

    meta_path = model_dir / f'metadata_{model_name}.parquet'
    embeddings_path = model_dir / f'embeddings_{model_name}.npy'
    small_chad_em_well.load(model_name, embeddings_path, meta_path)

    small_chad_em_well.apply_spherizing_transform(
        embeddings_name=model_name,
        new_embeddings_name=f"{model_name}_sph",
        norm_embeddings=False,
    )
    small_chad_em_well.apply_inverse_normal_transform(
        embeddings_name=f"{model_name}_sph",
        new_embeddings_name=f"{model_name}_sph_int",
    )

# Only the fully transformed ("*_sph_int") embeddings are evaluated below.
embeddings_to_test = [emb_name for emb_name in small_chad_em_well.embeddings if "sph_int" in emb_name]

In [None]:
# Restrict the QC subset to wells whose Metadata_JCP2022 is in
# `chad_em_well.JCP_ID_poscon` (presumably the positive-control compound IDs;
# TODO confirm).
small_chad_em_well_poscon = small_chad_em_well.filter_and_instantiate(Metadata_JCP2022=chad_em_well.JCP_ID_poscon)

In [None]:
# Compute mAP-style scores with Metadata_JCP2022 as the label, for each
# "*_sph_int" embedding. The random baseline is disabled and a plot is shown.
maps_jcp = small_chad_em_well_poscon.compute_maps(labels_column="Metadata_JCP2022", embeddings_names=embeddings_to_test, random_maps=False, plot=True)

In [None]:
# Compute LISI scores with Metadata_JCP2022 as the label, for each "*_sph_int"
# embedding, at neighbourhood sizes 5, 20 and 40.
lisi_jcp = small_chad_em_well_poscon.compute_lisi(labels_column="Metadata_JCP2022", embeddings_names=embeddings_to_test, plot=True, n_neighbors_list=[5, 20, 40])

In [None]:
import matplotlib.pyplot as plt

# Plot the LISI values of a subset of the evaluated embeddings, one line each.
# NOTE(review): the x axis is `lisi_jcp.index`, presumably the neighbourhood
# sizes; confirm and consider more descriptive axis labels.
model_columns = [
    'chad_cls_sph_int',
    'dinov2_s_sph_int',
    'chada_sph_int',
]

fig, ax = plt.subplots(figsize=(10, 6))
for column_name in model_columns:
    ax.plot(lisi_jcp.index, lisi_jcp[column_name], marker='o', label=column_name)

ax.set_xlabel("Index")
ax.set_ylabel("Values")
ax.set_title("Model Values")

# Anchor the legend just outside the right edge of the axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Shrink the layout so the external legend is not clipped.
fig.tight_layout()
plt.show()



In [None]:
# List the embedding names available on the positive-control subset.
small_chad_em_well_poscon.embeddings.keys()

In [None]:
# Plot a UMAP of the untransformed "resnet50" embedding (not its _sph_int
# variant), coloured by compound ID.
small_chad_em_well_poscon.plot_dimensionality_reduction(embedding_name='resnet50', color_by='Metadata_JCP2022', reduction_method='UMAP')